Skip to content

Commit 34343f5

Browse files
authored
[SOT] Refactor sot simulation to avoid hold real frame in graph and codegen (#71275)
1 parent 2562cd1 commit 34343f5

File tree

5 files changed

+58
-43
lines changed

5 files changed

+58
-43
lines changed

python/paddle/jit/sot/opcode_translator/executor/executor_cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
log_do,
3636
)
3737
from ..custom_code import CustomCode
38+
from .function_graph import FunctionGraph
3839
from .guard import Guard
3940
from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase
4041

@@ -235,12 +236,13 @@ def start_translate(
235236
Returns:
236237
tuple[CustomCode, Guard | None]: The translated code object and its guard function, or None if translation fails.
237238
"""
238-
simulator = OpcodeExecutor(frame, **kwargs)
239+
graph = FunctionGraph(frame.f_code, frame.f_globals, **kwargs)
240+
simulator = OpcodeExecutor(frame, graph)
239241
try:
240242
simulator.check_code_simulatable()
241243
InfoCollector().attach(CompileCountInfo, frame.f_code)
242244
with sot_simulation_mode_guard(True):
243-
new_custom_code, guard_fn = simulator.transform()
245+
new_custom_code, guard_fn = simulator.transform(frame)
244246
if not simulator._graph.need_cache:
245247
return (
246248
CustomCode(None, True),

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections import namedtuple
2323
from copy import deepcopy
2424
from functools import cached_property, reduce
25-
from typing import Any, Callable, Tuple, Union
25+
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union
2626

2727
from typing_extensions import TypeAlias, TypeGuard
2828

@@ -83,6 +83,10 @@
8383
map_variables,
8484
)
8585

86+
if TYPE_CHECKING:
87+
import types
88+
89+
8690
CompileGraphResult: TypeAlias = Tuple[
8791
Callable[..., Any],
8892
Tuple[
@@ -211,11 +215,13 @@ class FunctionGraph:
211215
],
212216
)
213217

214-
def __init__(self, frame, **kwargs):
218+
def __init__(
219+
self, code: types.CodeType, globals: dict[str, object], **kwargs
220+
):
215221
self.sir_ctx = SymbolicTraceContext()
216222
self.inner_out = set()
217223
self.input_variables = [] # Store variables required within a function
218-
self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True)
224+
self.pycode_gen = PyCodeGen(code, globals, disable_eval_frame=True)
219225
self.side_effects = SideEffects()
220226
self.need_cache = True
221227
self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()

python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
operator_not_in,
6767
)
6868
from .dispatcher import Dispatcher
69-
from .function_graph import FunctionGraph
7069
from .instr_flag import (
7170
CALL_FUNCTION_EX_FLAG as CFE,
7271
CONVERT_VALUE_FLAG as CV,
@@ -111,7 +110,7 @@
111110
)
112111

113112
if TYPE_CHECKING:
114-
from .function_graph import CompileGraphResult
113+
from .function_graph import CompileGraphResult, FunctionGraph
115114

116115
SUPPORT_COMPARE_OP = {
117116
">": operator.gt,
@@ -456,16 +455,6 @@ def _break_graph_when_if(self, result, instr: Instruction):
456455
"""
457456
raise NotImplementedError
458457

459-
def transform(self):
460-
"""
461-
Abstract method need to be implemented to symbolic translate each instruction.
462-
463-
Raises:
464-
NotImplementedError: If the method is not implemented.
465-
466-
"""
467-
raise NotImplementedError
468-
469458
def find_space_of_var_name(self, name):
470459
code = self._graph.pycode_gen._origin_code
471460
if name in (code.co_freevars + code.co_cellvars):
@@ -1886,18 +1875,17 @@ class OpcodeExecutor(OpcodeExecutorBase):
18861875
18871876
"""
18881877

1889-
def __init__(self, frame: types.FrameType, **kwargs):
1890-
graph = FunctionGraph(frame, **kwargs)
1891-
self._frame = frame
1878+
def __init__(self, frame: types.FrameType, graph: FunctionGraph):
1879+
self._frame = frame # TODO: Don't hold frame in executor, just hold vframe instead
18921880
self._name = "Executor"
18931881
self.call_stack[:] = []
18941882
super().__init__(frame.f_code, graph)
18951883
Dispatcher.graph = graph
18961884

1897-
def transform(self):
1898-
static_function = get_static_function(self._frame, "eval_frame")
1885+
def transform(self, frame: types.FrameType):
1886+
static_function = get_static_function(frame, "eval_frame")
18991887
if static_function is not None:
1900-
code = self._frame.f_code
1888+
code = frame.f_code
19011889
inputs = []
19021890
for i in range(code.co_argcount):
19031891
arg_name = code.co_varnames[i]
@@ -1932,11 +1920,9 @@ def _prepare_virtual_env(self):
19321920
"""
19331921
log(
19341922
3,
1935-
f"[Executor] code options: co_cellvars={self._frame.f_code.co_cellvars}\n",
1936-
)
1937-
free_or_cell_vars = (
1938-
self._frame.f_code.co_cellvars + self._frame.f_code.co_freevars
1923+
f"[Executor] code options: co_cellvars={self._code.co_cellvars}\n",
19391924
)
1925+
free_or_cell_vars = self._code.co_cellvars + self._code.co_freevars
19401926
for name, value in self._frame.f_locals.items():
19411927
tracker = (
19421928
CellTracker(name)
@@ -2109,7 +2095,10 @@ def create_if_branch_fn(
21092095
):
21102096
return None
21112097
cache_key = (ResumeFunctionType.IF_RESUME, self._code, start_idx)
2112-
resume_fn_creator = ResumeFunctionCreator(self._frame)
2098+
resume_fn_creator = ResumeFunctionCreator(
2099+
self._graph.pycode_gen._origin_code,
2100+
self._graph.pycode_gen._real_globals,
2101+
)
21132102
if (
21142103
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
21152104
) is not None:
@@ -2265,7 +2254,10 @@ def create_resume_fn(null_indices):
22652254
if self._instructions[next_index].opname == "RETURN_VALUE":
22662255
return None
22672256
cache_key = (ResumeFunctionType.CALL_RESUME, self._code, next_index)
2268-
resume_fn_creator = ResumeFunctionCreator(self._frame)
2257+
resume_fn_creator = ResumeFunctionCreator(
2258+
self._graph.pycode_gen._origin_code,
2259+
self._graph.pycode_gen._real_globals,
2260+
)
22692261
if (
22702262
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
22712263
) is not None:
@@ -2370,7 +2362,10 @@ def create_loop_body():
23702362
loop_body_start_idx,
23712363
loop_body_end_idx,
23722364
)
2373-
resume_fn_creator = ResumeFunctionCreator(self._frame)
2365+
resume_fn_creator = ResumeFunctionCreator(
2366+
self._graph.pycode_gen._origin_code,
2367+
self._graph.pycode_gen._real_globals,
2368+
)
23742369
if (
23752370
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
23762371
) is not None:
@@ -2441,7 +2436,10 @@ def create_after_loop_fn():
24412436
self._code,
24422437
loop_body_end_idx,
24432438
)
2444-
resume_fn_creator = ResumeFunctionCreator(self._frame)
2439+
resume_fn_creator = ResumeFunctionCreator(
2440+
self._graph.pycode_gen._origin_code,
2441+
self._graph.pycode_gen._real_globals,
2442+
)
24452443
if (
24462444
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
24472445
) is not None:
@@ -2601,7 +2599,10 @@ def create_inline_call_fn():
26012599
start_idx,
26022600
end_idx,
26032601
)
2604-
resume_fn_creator = ResumeFunctionCreator(self._frame)
2602+
resume_fn_creator = ResumeFunctionCreator(
2603+
self._graph.pycode_gen._origin_code,
2604+
self._graph.pycode_gen._real_globals,
2605+
)
26052606
if (
26062607
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
26072608
) is not None:

python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,10 @@ class PyCodeGen:
413413
"""Helper to create new code object"""
414414

415415
def __init__(
416-
self, frame: types.FrameType, disable_eval_frame: bool = False
416+
self,
417+
real_code: types.CodeType,
418+
real_globals: dict[str, object],
419+
disable_eval_frame: bool = False,
417420
):
418421
"""
419422
Initializes a PyCodeGen object.
@@ -422,11 +425,10 @@ def __init__(
422425
frame: The frame to be translated.
423426
disable_eval_frame (bool): Whether to disable the evaluation frame. Defaults to False.
424427
"""
425-
self._frame = frame
426-
self._origin_code = frame.f_code
428+
self._origin_code = real_code
427429
self._code_options = gen_code_options(self._origin_code)
428430
self.update_code_name("", is_resumed_fn=False)
429-
self._f_globals = frame.f_globals
431+
self._real_globals = real_globals
430432
self._instructions = []
431433
self.disable_eval_frame = disable_eval_frame
432434
self.hooks = []
@@ -666,8 +668,8 @@ def gen_load_object(self, obj, obj_name: str, push_null: bool = True):
666668
obj_name (str): The name of the object.
667669
"""
668670

669-
if obj_name not in self._f_globals:
670-
self._f_globals[obj_name] = obj
671+
if obj_name not in self._real_globals:
672+
self._real_globals[obj_name] = obj
671673
return self.gen_load_global(obj_name, push_null=push_null)
672674

673675
def gen_load_null_variable(self):
@@ -1007,9 +1009,12 @@ class ResumeFunctionCreator:
10071009
CODE_CACHE = {}
10081010

10091011
def __init__(
1010-
self, frame: types.FrameType, disable_eval_frame: bool = False
1012+
self,
1013+
code: types.CodeType,
1014+
globals: dict[str, object],
1015+
disable_eval_frame: bool = False,
10111016
):
1012-
self.codegen = PyCodeGen(frame, disable_eval_frame)
1017+
self.codegen = PyCodeGen(code, globals, disable_eval_frame)
10131018
self.name = ResumeFnNameFactory().next()
10141019

10151020
def set_inputs(
@@ -1064,7 +1069,7 @@ def lookup(self, cache_key):
10641069
cached_code = self.CODE_CACHE[cache_key]
10651070
ResumeFunctionCreator.validate_code(cached_code)
10661071
return types.FunctionType(
1067-
cached_code, self.codegen._f_globals, cached_code.co_name
1072+
cached_code, self.codegen._real_globals, cached_code.co_name
10681073
)
10691074
return None
10701075

@@ -1080,6 +1085,6 @@ def generate(self, cache_key=None) -> types.FunctionType:
10801085
self.CODE_CACHE[cache_key] = new_code
10811086
ResumeFunctionCreator.validate_code(new_code)
10821087
fn = types.FunctionType(
1083-
new_code, self.codegen._f_globals, new_code.co_name
1088+
new_code, self.codegen._real_globals, new_code.co_name
10841089
)
10851090
return fn

test/sot/test_sir_rollback.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def try_add(x, y):
4646
class TestRollback(TestCaseBase):
4747
def test_rollback(self):
4848
frame = inspect.currentframe()
49-
graph = FunctionGraph(frame)
49+
assert frame is not None
50+
graph = FunctionGraph(frame.f_code, frame.f_globals)
5051
a = paddle.to_tensor(1.0)
5152
b = paddle.to_tensor(2.0)
5253
a = VariableFactory().from_value(a, graph, LocalTracker("a"))

0 commit comments

Comments
 (0)