Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
log_do,
)
from ..custom_code import CustomCode
from .function_graph import FunctionGraph
from .guard import Guard
from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase

Expand Down Expand Up @@ -235,12 +236,13 @@ def start_translate(
Returns:
tuple[CustomCode, Guard | None]: The translated code object and its guard function, or None if translation fails.
"""
simulator = OpcodeExecutor(frame, **kwargs)
graph = FunctionGraph(frame.f_code, frame.f_globals, **kwargs)
simulator = OpcodeExecutor(frame, graph)
try:
simulator.check_code_simulatable()
InfoCollector().attach(CompileCountInfo, frame.f_code)
with sot_simulation_mode_guard(True):
new_custom_code, guard_fn = simulator.transform()
new_custom_code, guard_fn = simulator.transform(frame)
if not simulator._graph.need_cache:
return (
CustomCode(None, True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from collections import namedtuple
from copy import deepcopy
from functools import cached_property, reduce
from typing import Any, Callable, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union

from typing_extensions import TypeAlias, TypeGuard

Expand Down Expand Up @@ -83,6 +83,10 @@
map_variables,
)

if TYPE_CHECKING:
import types


CompileGraphResult: TypeAlias = Tuple[
Callable[..., Any],
Tuple[
Expand Down Expand Up @@ -211,11 +215,13 @@ class FunctionGraph:
],
)

def __init__(self, frame, **kwargs):
def __init__(
self, code: types.CodeType, globals: dict[str, object], **kwargs
):
self.sir_ctx = SymbolicTraceContext()
self.inner_out = set()
self.input_variables = [] # Store variables required within a function
self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True)
self.pycode_gen = PyCodeGen(code, globals, disable_eval_frame=True)
self.side_effects = SideEffects()
self.need_cache = True
self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()
Expand Down
55 changes: 28 additions & 27 deletions python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
operator_not_in,
)
from .dispatcher import Dispatcher
from .function_graph import FunctionGraph
from .instr_flag import (
CALL_FUNCTION_EX_FLAG as CFE,
CONVERT_VALUE_FLAG as CV,
Expand Down Expand Up @@ -111,7 +110,7 @@
)

if TYPE_CHECKING:
from .function_graph import CompileGraphResult
from .function_graph import CompileGraphResult, FunctionGraph

SUPPORT_COMPARE_OP = {
">": operator.gt,
Expand Down Expand Up @@ -456,16 +455,6 @@ def _break_graph_when_if(self, result, instr: Instruction):
"""
raise NotImplementedError

def transform(self):
"""
Abstract method need to be implemented to symbolic translate each instruction.

Raises:
NotImplementedError: If the method is not implemented.

"""
raise NotImplementedError

def find_space_of_var_name(self, name):
code = self._graph.pycode_gen._origin_code
if name in (code.co_freevars + code.co_cellvars):
Expand Down Expand Up @@ -1886,18 +1875,17 @@ class OpcodeExecutor(OpcodeExecutorBase):

"""

def __init__(self, frame: types.FrameType, **kwargs):
graph = FunctionGraph(frame, **kwargs)
self._frame = frame
def __init__(self, frame: types.FrameType, graph: FunctionGraph):
self._frame = frame # TODO: Don't hold frame in executor, just hold vframe instead
self._name = "Executor"
self.call_stack[:] = []
super().__init__(frame.f_code, graph)
Dispatcher.graph = graph

def transform(self):
static_function = get_static_function(self._frame, "eval_frame")
def transform(self, frame: types.FrameType):
static_function = get_static_function(frame, "eval_frame")
if static_function is not None:
code = self._frame.f_code
code = frame.f_code
inputs = []
for i in range(code.co_argcount):
arg_name = code.co_varnames[i]
Expand Down Expand Up @@ -1932,11 +1920,9 @@ def _prepare_virtual_env(self):
"""
log(
3,
f"[Executor] code options: co_cellvars={self._frame.f_code.co_cellvars}\n",
)
free_or_cell_vars = (
self._frame.f_code.co_cellvars + self._frame.f_code.co_freevars
f"[Executor] code options: co_cellvars={self._code.co_cellvars}\n",
)
free_or_cell_vars = self._code.co_cellvars + self._code.co_freevars
for name, value in self._frame.f_locals.items():
tracker = (
CellTracker(name)
Expand Down Expand Up @@ -2109,7 +2095,10 @@ def create_if_branch_fn(
):
return None
cache_key = (ResumeFunctionType.IF_RESUME, self._code, start_idx)
resume_fn_creator = ResumeFunctionCreator(self._frame)
resume_fn_creator = ResumeFunctionCreator(
self._graph.pycode_gen._origin_code,
self._graph.pycode_gen._real_globals,
)
if (
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
) is not None:
Expand Down Expand Up @@ -2265,7 +2254,10 @@ def create_resume_fn(null_indices):
if self._instructions[next_index].opname == "RETURN_VALUE":
return None
cache_key = (ResumeFunctionType.CALL_RESUME, self._code, next_index)
resume_fn_creator = ResumeFunctionCreator(self._frame)
resume_fn_creator = ResumeFunctionCreator(
self._graph.pycode_gen._origin_code,
self._graph.pycode_gen._real_globals,
)
if (
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
) is not None:
Expand Down Expand Up @@ -2370,7 +2362,10 @@ def create_loop_body():
loop_body_start_idx,
loop_body_end_idx,
)
resume_fn_creator = ResumeFunctionCreator(self._frame)
resume_fn_creator = ResumeFunctionCreator(
self._graph.pycode_gen._origin_code,
self._graph.pycode_gen._real_globals,
)
if (
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
) is not None:
Expand Down Expand Up @@ -2441,7 +2436,10 @@ def create_after_loop_fn():
self._code,
loop_body_end_idx,
)
resume_fn_creator = ResumeFunctionCreator(self._frame)
resume_fn_creator = ResumeFunctionCreator(
self._graph.pycode_gen._origin_code,
self._graph.pycode_gen._real_globals,
)
if (
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
) is not None:
Expand Down Expand Up @@ -2601,7 +2599,10 @@ def create_inline_call_fn():
start_idx,
end_idx,
)
resume_fn_creator = ResumeFunctionCreator(self._frame)
resume_fn_creator = ResumeFunctionCreator(
self._graph.pycode_gen._origin_code,
self._graph.pycode_gen._real_globals,
)
if (
maybe_resume_fn := resume_fn_creator.lookup(cache_key)
) is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,10 @@ class PyCodeGen:
"""Helper to create new code object"""

def __init__(
self, frame: types.FrameType, disable_eval_frame: bool = False
self,
real_code: types.CodeType,
real_globals: dict[str, object],
disable_eval_frame: bool = False,
):
"""
Initializes a PyCodeGen object.
Expand All @@ -422,11 +425,10 @@ def __init__(
frame: The frame to be translated.
disable_eval_frame (bool): Whether to disable the evaluation frame. Defaults to False.
"""
self._frame = frame
self._origin_code = frame.f_code
self._origin_code = real_code
self._code_options = gen_code_options(self._origin_code)
self.update_code_name("", is_resumed_fn=False)
self._f_globals = frame.f_globals
self._real_globals = real_globals
self._instructions = []
self.disable_eval_frame = disable_eval_frame
self.hooks = []
Expand Down Expand Up @@ -666,8 +668,8 @@ def gen_load_object(self, obj, obj_name: str, push_null: bool = True):
obj_name (str): The name of the object.
"""

if obj_name not in self._f_globals:
self._f_globals[obj_name] = obj
if obj_name not in self._real_globals:
self._real_globals[obj_name] = obj
return self.gen_load_global(obj_name, push_null=push_null)

def gen_load_null_variable(self):
Expand Down Expand Up @@ -1007,9 +1009,12 @@ class ResumeFunctionCreator:
CODE_CACHE = {}

def __init__(
self, frame: types.FrameType, disable_eval_frame: bool = False
self,
code: types.CodeType,
globals: dict[str, object],
disable_eval_frame: bool = False,
):
self.codegen = PyCodeGen(frame, disable_eval_frame)
self.codegen = PyCodeGen(code, globals, disable_eval_frame)
self.name = ResumeFnNameFactory().next()

def set_inputs(
Expand Down Expand Up @@ -1064,7 +1069,7 @@ def lookup(self, cache_key):
cached_code = self.CODE_CACHE[cache_key]
ResumeFunctionCreator.validate_code(cached_code)
return types.FunctionType(
cached_code, self.codegen._f_globals, cached_code.co_name
cached_code, self.codegen._real_globals, cached_code.co_name
)
return None

Expand All @@ -1080,6 +1085,6 @@ def generate(self, cache_key=None) -> types.FunctionType:
self.CODE_CACHE[cache_key] = new_code
ResumeFunctionCreator.validate_code(new_code)
fn = types.FunctionType(
new_code, self.codegen._f_globals, new_code.co_name
new_code, self.codegen._real_globals, new_code.co_name
)
return fn
3 changes: 2 additions & 1 deletion test/sot/test_sir_rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def try_add(x, y):
class TestRollback(TestCaseBase):
def test_rollback(self):
frame = inspect.currentframe()
graph = FunctionGraph(frame)
assert frame is not None
graph = FunctionGraph(frame.f_code, frame.f_globals)
a = paddle.to_tensor(1.0)
b = paddle.to_tensor(2.0)
a = VariableFactory().from_value(a, graph, LocalTracker("a"))
Expand Down
Loading