1
+ import ast
2
+ from typing import Union
3
+
4
+ from ...utils import walk_python_files
5
+ from ..models import Agent , Guardrail
6
+ from .ast_utils import (
7
+ find_decorator_by_name ,
8
+ )
9
+
10
+
11
+ class GuardrailsVisitor (ast .NodeVisitor ):
12
+ GUARDRAIL_DECORATOR_NAMES = ["input_guardrail" , "output_guardrail" ]
13
+ GUARDRAIL_CLASS_NAMES = ["InputGuardrail" , "OutputGuardrail" ]
14
+
15
+ def __init__ (self ) -> None :
16
+ super ().__init__ ()
17
+ self .guardrails : dict [str , Guardrail ] = {}
18
+ self .functions : dict [str , Union [ast .FunctionDef , ast .AsyncFunctionDef ]] = {}
19
+
20
+ def visit_FunctionDef (self , node ):
21
+ self .functions [node .name ] = node
22
+ self ._visit_any_function_def (node )
23
+
24
+ def visit_AsyncFunctionDef (self , node ):
25
+ self .functions [node .name ] = node
26
+ self ._visit_any_function_def (node )
27
+
28
+ def visit_Assign (self , node ):
29
+ """Tracks cases like:
30
+ guardrail = InputGuardrail(guardrail_function=function, name=guardrail_name)
31
+ """
32
+ if isinstance (node .value , ast .Call ):
33
+ call = node .value
34
+
35
+ if isinstance (call .func , ast .Name ) and call .func .id in self .GUARDRAIL_CLASS_NAMES :
36
+ class_name = call .func .id
37
+
38
+ guardrail_function_name = None
39
+ for kw in call .keywords :
40
+ if kw .arg == "guardrail_function" :
41
+ if isinstance (kw .value , ast .Name ):
42
+ guardrail_function_name = kw .value .id
43
+ elif isinstance (kw .value , ast .Attribute ):
44
+ guardrail_function_name = kw .value .attr # for something like module.guardrail_fn BUT SEE IF WE WANT THIS EVEN
45
+
46
+ if len (node .targets ) > 0 and isinstance (node .targets [0 ], ast .Name ):
47
+ guardrail_variable_name = node .targets [0 ].id
48
+ self .guardrails [guardrail_variable_name ] = Guardrail (
49
+ name = guardrail_variable_name ,
50
+ placement = "input" if class_name == self .GUARDRAIL_CLASS_NAMES [0 ] else "output" ,
51
+ uses_agent = False ,
52
+ guardrail_function_name = guardrail_function_name ,
53
+ _guardrail_function_def = None ,
54
+ agent_instructions = None
55
+ )
56
+
57
+
58
+ def _visit_any_function_def (
59
+ self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]
60
+ ):
61
+ input_guardrail_decorator = find_decorator_by_name (node , self .GUARDRAIL_DECORATOR_NAMES [0 ])
62
+ output_guardrail_decorator = find_decorator_by_name (node , self .GUARDRAIL_DECORATOR_NAMES [1 ])
63
+
64
+ if input_guardrail_decorator is None and output_guardrail_decorator is None :
65
+ return
66
+
67
+ guardrail_name = node .name
68
+
69
+ guardrail = Guardrail (
70
+ name = guardrail_name ,
71
+ placement = "input" if input_guardrail_decorator is not None else "output" ,
72
+ uses_agent = False ,
73
+ guardrail_function_name = guardrail_name ,
74
+ agent_instructions = None
75
+ )
76
+
77
+ guardrail ._guardrail_function_def = node
78
+ self .guardrails [guardrail_name ] = guardrail
79
+
80
+ def analyze_guardrail (guardrail_name : str , guardrail : Guardrail , functions : dict [str , Union [ast .FunctionDef , ast .AsyncFunctionDef ]], agent_assignments : dict [str , Agent ]) -> None :
81
+ # Check if the guardrail function has Runner.run in it, as this indicates that an agent is used
82
+ class RunnerCallVisitor (ast .NodeVisitor ):
83
+ def __init__ (self ):
84
+ self .found = False
85
+ self .starting_agent = None
86
+
87
+ def check_call (self , call : ast .Call ):
88
+ if isinstance (call .func , ast .Attribute ):
89
+ if (isinstance (call .func .value , ast .Name ) and
90
+ call .func .value .id == "Runner" and
91
+ call .func .attr == "run" ):
92
+ if call .args :
93
+ self .found = True
94
+ first_arg = call .args [0 ]
95
+ if isinstance (first_arg , ast .Name ):
96
+ self .starting_agent = first_arg .id
97
+
98
+ def visit_Assign (self , node : ast .Assign ):
99
+ if isinstance (node .value , ast .Await ) and isinstance (node .value .value , ast .Call ):
100
+ self .check_call (node .value .value )
101
+ self .generic_visit (node )
102
+
103
+ def visit_Expr (self , node : ast .Expr ):
104
+ if isinstance (node .value , ast .Await ) and isinstance (node .value .value , ast .Call ):
105
+ self .check_call (node .value .value )
106
+ self .generic_visit (node )
107
+
108
+ if guardrail ._guardrail_function_def is None :
109
+ guardrail ._guardrail_function_def = functions [guardrail .guardrail_function_name ]
110
+
111
+ visitor = RunnerCallVisitor ()
112
+ visitor .visit (guardrail ._guardrail_function_def )
113
+
114
+ if visitor .found :
115
+ guardrail .uses_agent = True
116
+ if visitor .starting_agent in agent_assignments .keys ():
117
+ guardrail .agent_instructions = agent_assignments [visitor .starting_agent ].instructions
118
+ guardrail .agent_name = visitor .starting_agent
119
+ agent_assignments [visitor .starting_agent ].is_guardrail = True
120
+ else :
121
+ print ("Oops, failed to find agent" )
122
+
123
+
124
+
125
+
126
+ def collect_guardrails (root_dir : str , agent_assignments : dict [str , Agent ]) -> dict [str , Guardrail ]:
127
+ all_guardrails : dict [str , Guardrail ] = {}
128
+ for file in walk_python_files (root_dir ):
129
+ with open (file , "r" ) as f :
130
+ try :
131
+ tree = ast .parse (f .read ())
132
+ except Exception as e :
133
+ print (f"Cannot parse Python module: { file } . Error: { e } " )
134
+ continue
135
+ guardrails_visitor = GuardrailsVisitor ()
136
+ guardrails_visitor .visit (tree )
137
+ for guardrail_name , guardrail in guardrails_visitor .guardrails .items ():
138
+ analyze_guardrail (guardrail_name , guardrail , guardrails_visitor .functions , agent_assignments )
139
+ all_guardrails |= guardrails_visitor .guardrails
140
+
141
+ return all_guardrails
0 commit comments