Skip to content

Commit 176e123

Browse files
authored
Support more cases in code_xforms (#1996)
* Support more cases in code_xforms
* Support multiple class names and functions to handle extracting both prediction and training functions.
* Handle globals referenced in function inputs.
* Handle subclasses referenced in the major class.
* Add test_strip_model_source_code_keeps_referenced_class_from_function
* Confirm that this behaviour is consistent with functions.
* Fix globals name to global_vars
* Do not collide with global namespace
1 parent de93023 commit 176e123

File tree

4 files changed

+406
-131
lines changed

4 files changed

+406
-131
lines changed

python/cog/code_xforms.py

Lines changed: 113 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
import re
33
import types
4-
from typing import Optional, Set, Union
4+
from typing import List, Optional, Set, Tuple, Union
55

66
COG_IMPORT_MODULES = {
77
"cog",
@@ -25,7 +25,7 @@ def load_module_from_string(
2525
return module
2626

2727

28-
def extract_class_source(source_code: str, class_name: str) -> str:
28+
def extract_class_sources(source_code: str, class_names: List[str]) -> List[str]:
2929
"""
3030
Extracts the source code for a specified class from a given source text.
3131
Args:
@@ -35,23 +35,36 @@ def extract_class_source(source_code: str, class_name: str) -> str:
3535
The source code of the specified class if found; otherwise, an empty string.
3636
"""
3737
class_name_pattern = re.compile(r"\b[A-Z]\w*\b")
38-
all_class_names = class_name_pattern.findall(class_name)
38+
all_class_names = []
39+
for class_name in class_names:
40+
all_class_names.extend(class_name_pattern.findall(class_name))
3941

4042
class ClassExtractor(ast.NodeVisitor):
4143
def __init__(self) -> None:
42-
self.class_source = None
44+
self.class_sources = []
4345

4446
def visit_ClassDef(self, node: ast.ClassDef) -> None: # pylint: disable=invalid-name
45-
if node.name in all_class_names:
46-
self.class_source = ast.get_source_segment(source_code, node)
47+
self.class_sources.append(node)
4748

4849
tree = ast.parse(source_code)
4950
extractor = ClassExtractor()
5051
extractor.visit(tree)
51-
return extractor.class_source if extractor.class_source else ""
5252

53+
valid_class_names = set(all_class_names)
54+
for node in extractor.class_sources:
55+
if node.name not in valid_class_names:
56+
continue
57+
for base_name in node.bases:
58+
valid_class_names.add(base_name.id)
5359

54-
def extract_function_source(source_code: str, function_name: str) -> str:
60+
return [
61+
str(ast.get_source_segment(source_code, x))
62+
for x in extractor.class_sources
63+
if x.name in valid_class_names
64+
]
65+
66+
67+
def extract_function_source(source_code: str, function_names: List[str]) -> str:
5568
"""
5669
Extracts the source code for a specified function from a given source text.
5770
Args:
@@ -63,20 +76,24 @@ def extract_function_source(source_code: str, function_name: str) -> str:
6376

6477
class FunctionExtractor(ast.NodeVisitor):
6578
def __init__(self) -> None:
66-
self.function_source = None
79+
self.function_sources = []
6780

6881
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # pylint: disable=invalid-name
69-
if node.name == function_name and not isinstance(node, ast.Module):
82+
if node.name in function_names and not isinstance(node, ast.Module):
7083
# Extract the source segment for this function definition
71-
self.function_source = ast.get_source_segment(source_code, node)
84+
self.function_sources.append(ast.get_source_segment(source_code, node))
7285

7386
tree = ast.parse(source_code)
7487
extractor = FunctionExtractor()
7588
extractor.visit(tree)
76-
return extractor.function_source if extractor.function_source else ""
89+
return "\n".join(extractor.function_sources)
7790

7891

79-
def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str) -> str:
92+
def make_class_methods_empty(
93+
source_code: Union[str, ast.AST],
94+
class_name: Optional[str],
95+
global_vars: List[ast.Assign],
96+
) -> Tuple[str, List[ast.Assign]]:
8097
"""
8198
Transforms the source code of a specified class to remove the bodies of all its methods
8299
and replace them with 'return None'.
@@ -88,26 +105,47 @@ def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str)
88105
"""
89106

90107
class MethodBodyTransformer(ast.NodeTransformer):
108+
def __init__(self, global_vars: List[ast.Assign]) -> None:
109+
self.used_globals = set()
110+
self._targets = {
111+
target.id: global_name
112+
for global_name in global_vars
113+
for target in global_name.targets
114+
if isinstance(target, ast.Name)
115+
}
116+
91117
def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]: # pylint: disable=invalid-name
92-
if node.name == class_name:
118+
if class_name is None or node.name == class_name:
93119
for body_item in node.body:
94120
if isinstance(body_item, ast.FunctionDef):
95121
# Replace the body of the method with `return None`
96122
body_item.body = [ast.Return(value=ast.Constant(value=None))]
123+
# Remove decorators from the function
124+
body_item.decorator_list = []
125+
# Determine if one of our globals is referenced by the function.
126+
for default in body_item.args.defaults:
127+
if isinstance(default, ast.Call):
128+
for keyword in default.keywords:
129+
if isinstance(keyword.value, ast.Name):
130+
corresponding_global = self._targets.get(
131+
keyword.value.id
132+
)
133+
if corresponding_global is not None:
134+
self.used_globals.add(corresponding_global)
97135
return node
98136

99137
return None
100138

101139
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
102-
transformer = MethodBodyTransformer()
140+
transformer = MethodBodyTransformer(global_vars)
103141
transformed_tree = transformer.visit(tree)
104142
class_code = ast.unparse(transformed_tree)
105-
return class_code
143+
return class_code, list(transformer.used_globals)
106144

107145

108146
def extract_method_return_type(
109-
source_code: Union[str, ast.AST], class_name: str, method_name: str
110-
) -> Optional[str]:
147+
source_code: Union[str, ast.AST], class_names: List[str], method_names: List[str]
148+
) -> List[str]:
111149
"""
112150
Extracts the return type annotation of a specified method within a given class from the source code.
113151
Args:
@@ -120,26 +158,26 @@ def extract_method_return_type(
120158

121159
class MethodReturnTypeExtractor(ast.NodeVisitor):
122160
def __init__(self) -> None:
123-
self.return_type = None
161+
self.return_types = []
124162

125163
def visit_ClassDef(self, node: ast.ClassDef) -> None: # pylint: disable=invalid-name
126-
if node.name == class_name:
164+
if node.name in class_names:
127165
self.generic_visit(node)
128166

129167
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # pylint: disable=invalid-name
130-
if node.name == method_name and node.returns:
131-
self.return_type = ast.unparse(node.returns)
168+
if node.name in method_names and node.returns:
169+
self.return_types.append(ast.unparse(node.returns))
132170

133171
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
134172
extractor = MethodReturnTypeExtractor()
135173
extractor.visit(tree)
136174

137-
return extractor.return_type
175+
return extractor.return_types
138176

139177

140-
def extract_function_return_type(
141-
source_code: Union[str, ast.AST], function_name: str
142-
) -> Optional[str]:
178+
def extract_function_return_types(
179+
source_code: Union[str, ast.AST], function_names: List[str]
180+
) -> List[str]:
143181
"""
144182
Extracts the return type annotation of a specified function from the source code.
145183
Args:
@@ -151,21 +189,23 @@ def extract_function_return_type(
151189

152190
class FunctionReturnTypeExtractor(ast.NodeVisitor):
153191
def __init__(self) -> None:
154-
self.return_type = None
192+
self.return_types = []
155193

156194
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # pylint: disable=invalid-name
157-
if node.name == function_name and node.returns:
195+
if node.name in function_names and node.returns:
158196
# Extract and return the string representation of the return type
159-
self.return_type = ast.unparse(node.returns)
197+
self.return_types.append(ast.unparse(node.returns))
160198

161199
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
162200
extractor = FunctionReturnTypeExtractor()
163201
extractor.visit(tree)
164202

165-
return extractor.return_type
203+
return extractor.return_types
166204

167205

168-
def make_function_empty(source_code: Union[str, ast.AST], function_name: str) -> str:
206+
def make_function_empty(
207+
source_code: Union[str, ast.AST], function_names: List[str]
208+
) -> str:
169209
"""
170210
Transforms the source code to remove the body of a specified function
171211
and replace it with 'return None'.
@@ -178,7 +218,7 @@ def make_function_empty(source_code: Union[str, ast.AST], function_name: str) ->
178218

179219
class FunctionBodyTransformer(ast.NodeTransformer):
180220
def visit_FunctionDef(self, node: ast.FunctionDef) -> Optional[ast.AST]: # pylint: disable=invalid-name
181-
if node.name == function_name:
221+
if node.name in function_names:
182222
# Replace the body of the function with `return None`
183223
node.body = [ast.Return(value=ast.Constant(value=None))]
184224
return node
@@ -224,8 +264,19 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # pylint: disable=inv
224264
return "\n".join(extractor.imports)
225265

226266

267+
def _extract_globals(source_code: Union[str, ast.AST]) -> List[ast.Assign]:
268+
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
269+
if isinstance(tree, ast.Module):
270+
return [x for x in tree.body if isinstance(x, ast.Assign)]
271+
return []
272+
273+
274+
def _render_globals(global_vars: List[ast.Assign]) -> str:
275+
return "\n".join([ast.unparse(x) for x in global_vars])
276+
277+
227278
def strip_model_source_code(
228-
source_code: str, class_name: str, method_name: str
279+
source_code: str, class_names: List[str], method_names: List[str]
229280
) -> Optional[str]:
230281
"""
231282
Strips down the model source code by extracting relevant classes and making methods empty.
@@ -240,32 +291,43 @@ def strip_model_source_code(
240291
Returns None if neither the class nor the function specified could be found or processed.
241292
"""
242293
imports = extract_specific_imports(source_code, COG_IMPORT_MODULES)
243-
class_source = (
244-
None if not class_name else extract_class_source(source_code, class_name)
294+
class_sources = (
295+
None if not class_names else extract_class_sources(source_code, class_names)
245296
)
246-
if class_source:
247-
class_source = make_class_methods_empty(class_source, class_name)
248-
return_type = extract_method_return_type(class_source, class_name, method_name)
249-
return_class_source = (
250-
extract_class_source(source_code, return_type) if return_type else ""
297+
global_vars = _extract_globals(source_code)
298+
if class_sources:
299+
class_source = "\n".join(class_sources)
300+
class_source, global_vars = make_class_methods_empty(
301+
class_source, None, global_vars
251302
)
252-
model_source = (
253-
imports + "\n\n" + return_class_source + "\n\n" + class_source + "\n"
303+
return_types = extract_method_return_type(
304+
class_source, class_names, method_names
305+
)
306+
return_class_sources = (
307+
extract_class_sources(source_code, return_types) if return_types else ""
308+
)
309+
return_class_source = "\n".join(return_class_sources)
310+
rendered_globals = _render_globals(global_vars)
311+
model_source = "\n".join(
312+
[
313+
x
314+
for x in [imports, rendered_globals, return_class_source, class_source]
315+
if x
316+
]
254317
)
255318
else:
256319
# use class_name specified in cog.yaml as method_name
257-
method_name = class_name
258-
function_source = extract_function_source(source_code, method_name)
320+
method_names = class_names
321+
function_source = extract_function_source(source_code, method_names)
259322
if not function_source:
260323
return None
261-
function_source = make_function_empty(function_source, method_name)
324+
function_source = make_function_empty(function_source, method_names)
262325
if not function_source:
263326
return None
264-
return_type = extract_function_return_type(function_source, method_name)
265-
return_class_source = (
266-
extract_class_source(source_code, return_type) if return_type else ""
267-
)
268-
model_source = (
269-
imports + "\n\n" + return_class_source + "\n\n" + function_source + "\n"
327+
return_types = extract_function_return_types(function_source, method_names)
328+
return_class_sources = (
329+
extract_class_sources(source_code, return_types) if return_types else ""
270330
)
331+
return_class_source = "\n".join(return_class_sources)
332+
model_source = "\n".join([imports, return_class_source, function_source])
271333
return model_source

python/cog/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def load_slim_predictor_from_file(
198198
) -> Optional[types.ModuleType]:
199199
with open(module_path, encoding="utf-8") as file:
200200
source_code = file.read()
201-
stripped_source = strip_model_source_code(source_code, class_name, method_name)
201+
stripped_source = strip_model_source_code(source_code, [class_name], [method_name])
202202
module = load_module_from_string(uuid.uuid4().hex, stripped_source)
203203
return module
204204

python/tests/server/test_code_xforms.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

0 commit comments

Comments
 (0)