Skip to content

Commit 95bb67f

Browse files
authored
Feat(agent-vulnerability): OpenAI Agents Improvements
* add: initial commit * feat: change template layout * fix: remove Tool Vulnerabilities section if no vulnerabilities are mapped * fix: ruff and mypy * add: initial commit * feat: change template layout * fix: remove Tool Vulnerabilities section if no vulnerabilities are mapped * fix: ruff and mypy * fix: openai client definition * fix: template changes * feat: add explanation column * fix: ruff * fix: stop analysis if there is no OpenAI key * fix: ruff * fix: ruff2 * feat: add agent vulnerability explanation table * feat: add mitigations and align to top
1 parent c10fe1a commit 95bb67f

File tree

14 files changed

+645
-66
lines changed

14 files changed

+645
-66
lines changed

agentic_radar/analysis/openai_agents/analyze.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from agentic_radar.analysis.openai_agents.graph import create_graph_definition
55
from agentic_radar.analysis.openai_agents.parsing import (
66
collect_agent_assignments,
7+
collect_guardrails,
78
collect_mcp_servers,
89
collect_tool_assignments,
10+
get_agent_vulnerabilities,
911
load_predefined_tools,
1012
)
1113
from agentic_radar.analysis.openai_agents.tool_categorizer.categorizer import (
@@ -28,10 +30,13 @@ def analyze(self, root_directory: str) -> GraphDefinition:
2830
predefined_tools=predefined_tools,
2931
mcp_servers=mcp_servers,
3032
)
33+
guardrails = collect_guardrails(root_directory, agent_assignments)
34+
get_agent_vulnerabilities(agent_assignments, guardrails)
3135
tool_categories = load_tool_categories()
3236
graph_definition = create_graph_definition(
3337
graph_name=Path(root_directory).name,
3438
agent_assignments=agent_assignments,
3539
tool_categories=tool_categories,
40+
guardrails=guardrails
3641
)
3742
return graph_definition

agentic_radar/analysis/openai_agents/graph.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
22

3-
from agentic_radar.analysis.openai_agents.models import Agent, Tool
3+
from agentic_radar.analysis.openai_agents.models import Agent, Guardrail, Tool
44
from agentic_radar.graph import (
55
Agent as ReportAgent,
66
)
77
from agentic_radar.graph import (
8+
AgentVulnerabilityDefinition,
89
EdgeDefinition,
910
GraphDefinition,
1011
NodeDefinition,
@@ -17,6 +18,7 @@ def create_graph_definition(
1718
graph_name: str,
1819
agent_assignments: dict[str, Agent],
1920
tool_categories: dict[str, ToolType],
21+
guardrails: dict[str, Guardrail]
2022
) -> GraphDefinition:
2123
nodes = []
2224
edges = []
@@ -69,6 +71,12 @@ def create_graph_definition(
6971
graph_mcp_server_nodes.add(name)
7072

7173
edges.append(EdgeDefinition(start=agent.name, end=name, condition="mcp"))
74+
75+
for guardrail_name, guardrail in guardrails.items():
76+
if guardrail.uses_agent:
77+
if guardrail_name in agent.guardrails["input"] or guardrail_name in agent.guardrails["output"]:
78+
if guardrail.agent_name and (guardrail_agent:=agent_assignments.get(guardrail.agent_name)):
79+
edges.append(EdgeDefinition(start=agent.name, end=guardrail_agent.name))
7280

7381
nodes, edges = _add_start_end_nodes(nodes=nodes, edges=edges)
7482

@@ -77,6 +85,14 @@ def create_graph_definition(
7785
name=agent.name,
7886
llm=agent.model or "gpt-4o",
7987
system_prompt=agent.instructions or "",
88+
is_guardrail=agent.is_guardrail,
89+
vulnerabilities=[
90+
AgentVulnerabilityDefinition(
91+
name=vulnerability.name,
92+
mitigation_level=vulnerability.mitigation_level,
93+
guardrail_explanation=vulnerability.guardrail_explanation,
94+
instruction_explanation=vulnerability.instruction_explanation
95+
) for vulnerability in agent.vulnerabilities]
8096
)
8197
for agent in agent_assignments.values()
8298
]

agentic_radar/analysis/openai_agents/models.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import ast
12
from enum import Enum
2-
from typing import Optional
3+
from typing import Literal, Optional, Union
34

4-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field, PrivateAttr
56

67

78
class Tool(BaseModel):
@@ -22,10 +23,30 @@ class MCPServerInfo(BaseModel):
2223
params: dict[str, str] = {}
2324

2425

26+
class Guardrail(BaseModel):
    """A guardrail declaration found in an analysed OpenAI Agents project."""

    # Key the guardrail is registered under (decorated function name or the
    # variable an InputGuardrail/OutputGuardrail instance is assigned to).
    name: str
    # Whether the guardrail is an input or an output guardrail.
    placement: Literal['input', 'output']
    # Set to True once the guardrail function is found to run an agent.
    uses_agent: bool
    # Name of the Python function implementing the guardrail.
    guardrail_function_name: str
    # AST of the guardrail function. Declared as a pydantic private attribute,
    # so it is not a model field and must be assigned after construction.
    _guardrail_function_def: Optional[Union[ast.FunctionDef, ast.AsyncFunctionDef]] = PrivateAttr(default=None)
    # Instructions of the agent the guardrail runs, when that agent is known.
    agent_instructions: Optional[str] = None
    # Name of the agent the guardrail runs, when that agent is known.
    agent_name: Optional[str] = None
34+
35+
36+
class AgentVulnerability(BaseModel):
    """A vulnerability attributed to an agent, with its mitigation status."""

    # Name of the vulnerability.
    name: str
    # How far the vulnerability is mitigated; restricted to these three labels.
    mitigation_level: Literal["None", "Partial", "Full"]
    # NOTE(review): presumably describes how the agent's guardrails affect
    # this vulnerability - confirm against the vulnerability analysis code.
    guardrail_explanation: str
    # NOTE(review): presumably describes how the agent's instructions affect
    # this vulnerability - confirm against the vulnerability analysis code.
    instruction_explanation: str
41+
42+
2543
class Agent(BaseModel):
    """An agent extracted from an ``Agent(...)`` constructor call."""

    name: str
    tools: list[Tool]
    # Handoff targets, stored as strings.
    handoffs: list[str]
    # Agent instructions; reported as the system prompt in the graph.
    instructions: Optional[str] = None
    # Model name, if one was given to the constructor.
    model: Optional[str] = None
    mcp_servers: list[MCPServerInfo] = []
    # Marked True when a guardrail function runs this agent.
    is_guardrail: bool = False
    # Guardrail names attached to the agent, keyed by "input" / "output".
    # NOTE(review): the default is an empty dict, while graph construction
    # indexes guardrails["input"] and guardrails["output"] directly - confirm
    # every producer of Agent populates both keys.
    guardrails: dict[str, list[str]] = Field(default_factory=dict)
    # Vulnerabilities attributed to this agent.
    vulnerabilities: list[AgentVulnerability] = Field(default_factory=list)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from .agents import collect_agent_assignments
2+
from .guardrails import collect_guardrails
23
from .mcp import collect_mcp_servers
34
from .predefined_tools import load_predefined_tools
45
from .tools import collect_tool_assignments
6+
from .vulnerabilities import get_agent_vulnerabilities
57

68
# Names re-exported by this package; each one is imported at the top of
# this module.
__all__ = [
    "collect_agent_assignments",
    "collect_tool_assignments",
    "load_predefined_tools",
    "collect_mcp_servers",
    "collect_guardrails",
    "get_agent_vulnerabilities"
]

agentic_radar/analysis/openai_agents/parsing/agents.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
from typing import Union
23

34
from pydantic import ValidationError
45

@@ -84,13 +85,16 @@ def _extract_agent(self, agent_node: ast.Call) -> Agent:
8485
model = get_string_keyword_arg(agent_node, "model")
8586
mcp_servers = self._extract_agent_mcp_servers(agent_node)
8687

88+
guardrails = self._extract_agent_guardrails(agent_node)
89+
8790
return Agent(
8891
name=name,
8992
tools=tools,
9093
handoffs=handoffs,
9194
instructions=instructions,
9295
model=model,
9396
mcp_servers=mcp_servers,
97+
guardrails=guardrails
9498
)
9599
except (ValueError, ValidationError, ValueError) as e:
96100
raise InvalidAgentConstructorError from e
@@ -127,6 +131,26 @@ def _extract_agent_tools(self, agent_node: ast.Call) -> list[Tool]:
127131
)
128132

129133
return tools
134+
135+
def _extract_agent_guardrails(self, agent_node: ast.Call) -> dict[str, list[str]]:
    """Return the guardrail names referenced by an ``Agent(...)`` call.

    Reads the ``input_guardrails`` and ``output_guardrails`` keyword
    arguments. Only list literals are inspected; within a list, bare names
    contribute their identifier and attribute accesses contribute the
    attribute name. Any other element or argument shape is ignored.

    Args:
        agent_node: The ``Agent(...)`` constructor call.

    Returns:
        A dict with exactly the keys ``"input"`` and ``"output"``, each
        mapping to a (possibly empty) list of guardrail names.
    """

    def names_in(list_node: Union[ast.AST, None]) -> list[str]:
        # A missing argument and a non-list argument both yield no names.
        if not isinstance(list_node, ast.List):
            return []
        collected: list[str] = []
        for element in list_node.elts:
            if isinstance(element, ast.Name):
                collected.append(element.id)
            elif isinstance(element, ast.Attribute):
                collected.append(element.attr)
        return collected

    guardrail_nodes = {
        "input": get_keyword_arg_value(agent_node, "input_guardrails"),
        "output": get_keyword_arg_value(agent_node, "output_guardrails"),
    }
    return {placement: names_in(node) for placement, node in guardrail_nodes.items()}
130154

131155
def _extract_agent_handoffs(self, agent_node: ast.Call) -> list[str]:
132156
handoffs_node = get_keyword_arg_value(agent_node, "handoffs")
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import ast
2+
from typing import Union
3+
4+
from ...utils import walk_python_files
5+
from ..models import Agent, Guardrail
6+
from .ast_utils import (
7+
find_decorator_by_name,
8+
)
9+
10+
11+
class GuardrailsVisitor(ast.NodeVisitor):
    """Collect guardrail declarations from one parsed Python module.

    Two declaration forms are recognised:

    * functions decorated with ``@input_guardrail`` / ``@output_guardrail``;
    * assignments such as
      ``guardrail = InputGuardrail(guardrail_function=fn, name="...")``.

    The function visitors do not descend into function bodies, so
    declarations nested inside another function are not seen.
    """

    GUARDRAIL_DECORATOR_NAMES = ["input_guardrail", "output_guardrail"]
    GUARDRAIL_CLASS_NAMES = ["InputGuardrail", "OutputGuardrail"]

    def __init__(self) -> None:
        super().__init__()
        # Guardrails keyed by function name (decorator form) or by the
        # assigned variable name (constructor form).
        self.guardrails: dict[str, Guardrail] = {}
        # Every visited function definition, keyed by name, so that
        # constructor-form guardrails can be matched to their function later.
        self.functions: dict[str, Union[ast.FunctionDef, ast.AsyncFunctionDef]] = {}

    @staticmethod
    def _referenced_name(node: Union[ast.AST, None]) -> Union[str, None]:
        """Return the trailing identifier of a ``name`` or ``pkg.name`` node."""
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.functions[node.name] = node
        self._visit_any_function_def(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.functions[node.name] = node
        self._visit_any_function_def(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Track cases like:
        guardrail = InputGuardrail(guardrail_function=function, name=guardrail_name)

        The guardrail function may be given by keyword or as the first
        positional argument. Declarations whose function cannot be resolved
        to a name (lambda, call result, missing argument) are skipped:
        ``Guardrail.guardrail_function_name`` is a required string, and
        building the model without it would abort the whole scan.
        """
        call = node.value
        if not isinstance(call, ast.Call):
            return

        class_name = self._referenced_name(call.func)
        if class_name not in self.GUARDRAIL_CLASS_NAMES:
            return

        target = node.targets[0] if node.targets else None
        if not isinstance(target, ast.Name):
            return

        function_node = None
        for kw in call.keywords:
            if kw.arg == "guardrail_function":
                function_node = kw.value
        if function_node is None and call.args:
            # guardrail_function is the first field of both guardrail classes.
            function_node = call.args[0]

        guardrail_function_name = self._referenced_name(function_node)
        if guardrail_function_name is None:
            return

        is_input = class_name == self.GUARDRAIL_CLASS_NAMES[0]
        self.guardrails[target.id] = Guardrail(
            name=target.id,
            placement="input" if is_input else "output",
            uses_agent=False,
            guardrail_function_name=guardrail_function_name,
        )

    def _visit_any_function_def(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Register ``node`` as a guardrail if it carries a guardrail decorator."""
        input_guardrail_decorator = find_decorator_by_name(node, self.GUARDRAIL_DECORATOR_NAMES[0])
        output_guardrail_decorator = find_decorator_by_name(node, self.GUARDRAIL_DECORATOR_NAMES[1])

        if input_guardrail_decorator is None and output_guardrail_decorator is None:
            return

        guardrail = Guardrail(
            name=node.name,
            placement="input" if input_guardrail_decorator is not None else "output",
            uses_agent=False,
            guardrail_function_name=node.name,
        )
        # The function AST is a private attribute on the model, so it has to
        # be attached after construction.
        guardrail._guardrail_function_def = node
        self.guardrails[node.name] = guardrail
79+
80+
def analyze_guardrail(guardrail_name: str, guardrail: Guardrail, functions: dict[str, Union[ast.FunctionDef, ast.AsyncFunctionDef]], agent_assignments: dict[str, Agent]) -> None:
    """Detect whether a guardrail delegates to an agent and record which one.

    The guardrail function is searched for an awaited ``Runner.run(...)``
    call. If one is found, ``guardrail.uses_agent`` is set. When the starting
    agent is a plain name present in ``agent_assignments``, the guardrail's
    ``agent_name`` / ``agent_instructions`` are filled in and that agent is
    flagged with ``is_guardrail = True``.

    Args:
        guardrail_name: Key the guardrail is registered under; used in
            diagnostics only.
        guardrail: Guardrail to analyse; updated in place.
        functions: Function definitions of the module the guardrail was
            found in, keyed by function name.
        agent_assignments: Known agents keyed by name; the matched agent is
            updated in place.
    """

    class RunnerCallVisitor(ast.NodeVisitor):
        """Finds awaited ``Runner.run`` calls and their starting agent."""

        def __init__(self) -> None:
            self.found = False
            self.starting_agent: Union[str, None] = None

        def check_call(self, call: ast.Call) -> None:
            func = call.func
            is_runner_run = (
                isinstance(func, ast.Attribute)
                and func.attr == "run"
                and isinstance(func.value, ast.Name)
                and func.value.id == "Runner"
            )
            if not is_runner_run:
                return

            # The starting agent is the first positional argument, or the
            # ``starting_agent`` keyword.
            agent_arg = call.args[0] if call.args else None
            if agent_arg is None:
                for kw in call.keywords:
                    if kw.arg == "starting_agent":
                        agent_arg = kw.value
            if agent_arg is None:
                return

            self.found = True
            if isinstance(agent_arg, ast.Name):
                self.starting_agent = agent_arg.id

        def visit_Await(self, node: ast.Await) -> None:
            # Matching on the await itself covers assignments, bare
            # expression statements and ``return await ...`` alike.
            if isinstance(node.value, ast.Call):
                self.check_call(node.value)
            self.generic_visit(node)

    function_def = guardrail._guardrail_function_def
    if function_def is None:
        function_def = functions.get(guardrail.guardrail_function_name)
        if function_def is None:
            # The function is not defined in this module (e.g. imported), so
            # there is nothing to inspect; leave the guardrail untouched.
            return
        guardrail._guardrail_function_def = function_def

    visitor = RunnerCallVisitor()
    visitor.visit(function_def)
    if not visitor.found:
        return

    guardrail.uses_agent = True
    agent = agent_assignments.get(visitor.starting_agent)
    if agent is None:
        print(
            f"Guardrail '{guardrail_name}' runs an agent that could not be "
            f"resolved: {visitor.starting_agent!r}"
        )
        return

    guardrail.agent_instructions = agent.instructions
    guardrail.agent_name = visitor.starting_agent
    agent.is_guardrail = True
122+
123+
124+
125+
126+
def collect_guardrails(root_dir: str, agent_assignments: dict[str, Agent]) -> dict[str, Guardrail]:
    """Collect and analyse every guardrail declared under ``root_dir``.

    Each Python file is parsed on its own. Files that cannot be parsed are
    reported on stdout and skipped. Guardrails found in later files replace
    earlier ones registered under the same name.

    Args:
        root_dir: Directory that is searched for Python files.
        agent_assignments: Known agents keyed by name; agents that are run
            by a guardrail are flagged in place during analysis.

    Returns:
        All discovered guardrails keyed by name.
    """
    all_guardrails: dict[str, Guardrail] = {}
    for file in walk_python_files(root_dir):
        # Read raw bytes so the parser honours the file's own encoding
        # declaration or BOM instead of the platform's locale encoding.
        with open(file, "rb") as f:
            try:
                tree = ast.parse(f.read())
            except Exception as e:
                print(f"Cannot parse Python module: {file}. Error: {e}")
                continue

        # The file is closed by now; analysis only needs the parsed tree.
        guardrails_visitor = GuardrailsVisitor()
        guardrails_visitor.visit(tree)
        for guardrail_name, guardrail in guardrails_visitor.guardrails.items():
            analyze_guardrail(guardrail_name, guardrail, guardrails_visitor.functions, agent_assignments)
        all_guardrails |= guardrails_visitor.guardrails

    return all_guardrails

0 commit comments

Comments
 (0)