1
1
import ast
2
2
import re
3
3
import types
4
- from typing import Optional , Set , Union
4
+ from typing import List , Optional , Set , Tuple , Union
5
5
6
6
COG_IMPORT_MODULES = {
7
7
"cog" ,
@@ -25,7 +25,7 @@ def load_module_from_string(
25
25
return module
26
26
27
27
28
def extract_class_sources(source_code: str, class_names: List[str]) -> List[str]:
    """
    Extracts the source code of the requested classes, together with the
    classes they (transitively) inherit from, from a given source text.

    Args:
        source_code: The complete source code as a string.
        class_names: Class names or type expressions (e.g. ``List[Output]``).
            Every capitalised identifier found in an entry is treated as a
            candidate class name.

    Returns:
        The source segments of all matching classes, in the order in which the
        classes appear in ``source_code``. Empty if nothing matches.
    """
    class_name_pattern = re.compile(r"\b[A-Z]\w*\b")
    wanted_names = {
        name
        for class_name in class_names
        for name in class_name_pattern.findall(class_name)
    }

    class ClassExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.class_nodes = []

        def visit_ClassDef(self, node: ast.ClassDef) -> None:  # pylint: disable=invalid-name
            # No generic_visit() on purpose: classes nested inside another
            # class are not collected.
            self.class_nodes.append(node)

    def base_identifier(base: ast.expr) -> Optional[str]:
        # `Base[T]` -> `Base`; dotted bases such as `cog.BasePredictor` come
        # from another module and can never be defined in this source.
        while isinstance(base, ast.Subscript):
            base = base.value
        return base.id if isinstance(base, ast.Name) else None

    extractor = ClassExtractor()
    extractor.visit(ast.parse(source_code))

    # Pull in base classes until nothing new is discovered, so a base that is
    # defined *before* its subclass still contributes its own bases.
    changed = True
    while changed:
        changed = False
        for node in extractor.class_nodes:
            if node.name not in wanted_names:
                continue
            for base in node.bases:
                base_name = base_identifier(base)
                if base_name is not None and base_name not in wanted_names:
                    wanted_names.add(base_name)
                    changed = True

    return [
        str(ast.get_source_segment(source_code, node))
        for node in extractor.class_nodes
        if node.name in wanted_names
    ]
65
+
66
+
67
def extract_function_source(source_code: str, function_names: List[str]) -> str:
    """
    Extracts the source code for the specified functions from a given source text.

    Args:
        source_code: The complete source code as a string.
        function_names: Names of the functions to extract.

    Returns:
        The source segments of every matching ``def``, joined with newlines in
        visiting order; an empty string if none is found.
    """
    # Tolerate a missing name list and get O(1) membership checks.
    wanted_names = set(function_names or ())

    class FunctionExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.function_sources = []

        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # pylint: disable=invalid-name
            if node.name not in wanted_names:
                return
            # Extract the source segment for this function definition
            segment = ast.get_source_segment(source_code, node)
            if segment is not None:
                self.function_sources.append(segment)

    extractor = FunctionExtractor()
    extractor.visit(ast.parse(source_code))
    return "\n".join(extractor.function_sources)
77
90
78
91
79
def make_class_methods_empty(
    source_code: Union[str, ast.AST],
    class_name: Optional[str],
    global_vars: List[ast.Assign],
) -> Tuple[str, List[ast.Assign]]:
    """
    Transforms the source code of a class to remove the bodies of all its methods
    and replace them with 'return None'. Method decorators are removed as well.

    Args:
        source_code: The source text, or an already parsed AST (which is
            modified in place).
        class_name: The class to keep. Other classes are dropped from the
            output. ``None`` keeps and transforms every class.
        global_vars: Module-level assignments that method defaults may refer to.

    Returns:
        A tuple of the transformed source code and the entries of
        ``global_vars`` referenced by a method's default values, in the same
        order as in ``global_vars``.
    """
    # Map every plain-name assignment target to the statement defining it.
    assignments_by_name = {
        target.id: assignment
        for assignment in global_vars
        for target in assignment.targets
        if isinstance(target, ast.Name)
    }
    used_globals = set()

    class MethodBodyTransformer(ast.NodeTransformer):
        def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]:  # pylint: disable=invalid-name
            if class_name is not None and node.name != class_name:
                return None
            for body_item in node.body:
                if not isinstance(body_item, ast.FunctionDef):
                    continue
                # Replace the body of the method with `return None`
                body_item.body = [ast.Return(value=ast.Constant(value=None))]
                # Remove decorators from the function
                body_item.decorator_list = []
                # Determine if one of our globals is referenced by the function.
                # The whole default expression is scanned, so `x=GLOBAL`,
                # `x=Input(default=GLOBAL)` and keyword-only defaults all count.
                arguments = body_item.args
                for default in (*arguments.defaults, *arguments.kw_defaults):
                    if default is None:  # keyword-only argument without default
                        continue
                    for ref in ast.walk(default):
                        if not isinstance(ref, ast.Name):
                            continue
                        assignment = assignments_by_name.get(ref.id)
                        if assignment is not None:
                            used_globals.add(assignment)
            return node

    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    transformed_tree = MethodBodyTransformer().visit(tree)
    if transformed_tree is None:
        # The root node itself was a class that got filtered out.
        return "", []
    class_code = ast.unparse(transformed_tree)
    # Keep source order so the rendered globals are deterministic.
    ordered_globals = [x for x in global_vars if x in used_globals]
    return class_code, ordered_globals
106
144
107
145
108
146
def extract_method_return_type(
    source_code: Union[str, ast.AST], class_names: List[str], method_names: List[str]
) -> List[str]:
    """
    Extracts the return type annotations of the specified methods within the given
    classes from the source code.

    Args:
        source_code: The source text, or an already parsed AST.
        class_names: Names of the classes whose methods are inspected.
        method_names: Names of the methods to look for.

    Returns:
        The unparsed return annotation of every matching, annotated method in
        visiting order; an empty list if there is none.
    """
    wanted_classes = set(class_names or ())
    wanted_methods = set(method_names or ())

    class MethodReturnTypeExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.return_types = []
            # > 0 while visiting the body of a requested class.
            self._class_depth = 0

        def visit_ClassDef(self, node: ast.ClassDef) -> None:  # pylint: disable=invalid-name
            if node.name in wanted_classes:
                self._class_depth += 1
                self.generic_visit(node)
                self._class_depth -= 1

        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # pylint: disable=invalid-name
            # Module-level functions that merely share a method's name are
            # not methods of the requested classes and must be ignored.
            if not self._class_depth:
                return
            if node.name in wanted_methods and node.returns is not None:
                self.return_types.append(ast.unparse(node.returns))

    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    extractor = MethodReturnTypeExtractor()
    extractor.visit(tree)

    return extractor.return_types
138
176
139
177
140
def extract_function_return_types(
    source_code: Union[str, ast.AST], function_names: List[str]
) -> List[str]:
    """
    Extracts the return type annotations of the specified functions from the source code.

    Args:
        source_code: The source text, or an already parsed AST.
        function_names: Names of the functions to look for.

    Returns:
        The unparsed return annotation of every matching, annotated function in
        visiting order; an empty list if there is none.
    """
    wanted_names = set(function_names or ())

    class FunctionReturnTypeExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.return_types = []

        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # pylint: disable=invalid-name
            if node.name in wanted_names and node.returns is not None:
                # Extract the string representation of the return type
                self.return_types.append(ast.unparse(node.returns))

    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    extractor = FunctionReturnTypeExtractor()
    extractor.visit(tree)

    return extractor.return_types
166
204
167
205
168
- def make_function_empty (source_code : Union [str , ast .AST ], function_name : str ) -> str :
206
+ def make_function_empty (
207
+ source_code : Union [str , ast .AST ], function_names : List [str ]
208
+ ) -> str :
169
209
"""
170
210
Transforms the source code to remove the body of a specified function
171
211
and replace it with 'return None'.
@@ -178,7 +218,7 @@ def make_function_empty(source_code: Union[str, ast.AST], function_name: str) ->
178
218
179
219
class FunctionBodyTransformer (ast .NodeTransformer ):
180
220
def visit_FunctionDef (self , node : ast .FunctionDef ) -> Optional [ast .AST ]: # pylint: disable=invalid-name
181
- if node .name == function_name :
221
+ if node .name in function_names :
182
222
# Replace the body of the function with `return None`
183
223
node .body = [ast .Return (value = ast .Constant (value = None ))]
184
224
return node
@@ -224,8 +264,19 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # pylint: disable=inv
224
264
return "\n " .join (extractor .imports )
225
265
226
266
267
def _extract_globals(source_code: Union[str, ast.AST]) -> List[ast.Assign]:
    """
    Collects the plain assignment statements at module level.

    Args:
        source_code: The source text, or an already parsed AST.

    Returns:
        The ``ast.Assign`` nodes found directly in the module body, in source
        order. Annotated assignments are not ``ast.Assign`` nodes and are not
        included. Empty if the parsed root is not an ``ast.Module``.
    """
    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    if not isinstance(tree, ast.Module):
        return []
    return [statement for statement in tree.body if isinstance(statement, ast.Assign)]
272
+
273
+
274
def _render_globals(global_vars: List[ast.Assign]) -> str:
    """
    Renders assignment nodes back to source text, one statement per line.

    Args:
        global_vars: The assignment nodes to render.

    Returns:
        The unparsed statements joined with newlines; an empty string for an
        empty list.
    """
    return "\n".join(ast.unparse(assignment) for assignment in global_vars)
276
+
277
+
227
278
def strip_model_source_code(
    source_code: str, class_names: List[str], method_names: List[str]
) -> Optional[str]:
    """
    Strips down the model source code by extracting relevant classes and making methods empty.

    If one of ``class_names`` is defined as a class, the result holds the
    selected imports, the globals referenced by method defaults, the classes
    used as return types and the emptied classes. Otherwise ``class_names`` is
    interpreted as a list of function names and the emptied functions are
    emitted instead.

    Args:
        source_code: The complete model source code.
        class_names: Names of the classes (or functions) to keep.
        method_names: Names of the methods whose return types must be kept.

    Returns:
        The stripped source code.
        Returns None if neither the class nor the function specified could be found or processed.
    """
    imports = extract_specific_imports(source_code, COG_IMPORT_MODULES)
    class_sources = (
        extract_class_sources(source_code, class_names) if class_names else []
    )

    if class_sources:
        class_source, global_vars = make_class_methods_empty(
            "\n".join(class_sources), None, _extract_globals(source_code)
        )
        return_types = extract_method_return_type(
            class_source, class_names, method_names
        )
        return_class_sources = (
            extract_class_sources(source_code, return_types) if return_types else []
        )
        sections = [
            imports,
            _render_globals(global_vars),
            "\n".join(return_class_sources),
            class_source,
        ]
        # Empty sections are skipped so they do not produce blank lines.
        return "\n".join(section for section in sections if section)

    if not class_names:
        # Nothing was requested, so there is no function to fall back to.
        return None

    # use class_name specified in cog.yaml as method_name
    method_names = class_names
    function_source = extract_function_source(source_code, method_names)
    if not function_source:
        return None
    function_source = make_function_empty(function_source, method_names)
    if not function_source:
        return None
    return_types = extract_function_return_types(function_source, method_names)
    return_class_sources = (
        extract_class_sources(source_code, return_types) if return_types else []
    )
    return_class_source = "\n".join(return_class_sources)
    return "\n".join([imports, return_class_source, function_source])
0 commit comments