Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,7 @@ def generate(self, cache_key=None) -> types.FunctionType:
self.codegen._code_options['co_flags'] &= ~(
inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
)
self.codegen._code_options['co_kwonlyargcount'] = 0
new_code = self.codegen.gen_pycode()
# TODO(SigureMo): cache_key should not be None
if cache_key is not None:
Expand Down
39 changes: 39 additions & 0 deletions test/sot/test_06_call_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph


def add(x, y):
Expand Down Expand Up @@ -178,5 +179,43 @@ def test_apply_fn(self):
self.assertEqual(ctx.translate_count, 2)


@check_no_breakgraph
def positional_only_basic(x, /, y):
z = x + y
return x + z


def positional_only_breakgraph(x, /, y):
z = x + y
paddle.jit.sot.psdb.breakgraph()
return x + z


@check_no_breakgraph
def keyword_only_basic(x, *, y):
z = x + y
return x + z


def keyword_only_breakgraph(x, *, y):
z = x + y
paddle.jit.sot.psdb.breakgraph()
return x + z


class TestPositionalKeywordOnly(TestCaseBase):
def test_positional_only_basic(self):
self.assert_results(positional_only_basic, 1, 2)

def test_positional_only_breakgraph(self):
self.assert_results(positional_only_breakgraph, 1, 2)

def test_keyword_only_basic(self):
self.assert_results(keyword_only_basic, x=1, y=2)

def test_keyword_only_breakgraph(self):
self.assert_results(keyword_only_breakgraph, x=1, y=2)


if __name__ == "__main__":
unittest.main()
25 changes: 13 additions & 12 deletions test/sot/test_case_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,22 @@ def assert_nest_match(self, x, y):
else:
self.assertEqual(x, y)

def assert_results(self, func, *inputs):
sym_output = symbolic_translate(func)(*inputs)
paddle_output = func(*inputs)
def assert_results(self, func, *args, **kwargs):
sym_output = symbolic_translate(func)(*args, **kwargs)
paddle_output = func(*args, **kwargs)
self.assert_nest_match(sym_output, paddle_output)

def assert_results_with_side_effects(self, func, *inputs):
sym_inputs = copy.deepcopy(inputs)
sym_output = symbolic_translate(func)(*sym_inputs)
paddle_inputs = copy.deepcopy(inputs)
paddle_output = func(*paddle_inputs)
self.assert_nest_match(sym_inputs, paddle_inputs)
def assert_results_with_side_effects(self, func, *args, **kwargs):
sym_args, sym_kwargs = copy.deepcopy((args, kwargs))
sym_output = symbolic_translate(func)(*sym_args, **sym_kwargs)
paddle_args, paddle_kwargs = copy.deepcopy((args, kwargs))
paddle_output = func(*paddle_args, **paddle_kwargs)
self.assert_nest_match(sym_args, paddle_args)
self.assert_nest_match(sym_kwargs, paddle_kwargs)
self.assert_nest_match(sym_output, paddle_output)

def assert_results_with_global_check(
self, func, global_keys: list[str], *inputs
self, func, global_keys: list[str], *args, **kwargs
):
def copy_fn(fn):
return types.FunctionType(
Expand All @@ -122,8 +123,8 @@ def copy_fn(fn):
sym_copied_fn = copy_fn(func)
sym_fn = symbolic_translate(sym_copied_fn)
paddle_fn = copy_fn(func)
sym_output = sym_fn(*inputs)
paddle_output = paddle_fn(*inputs)
sym_output = sym_fn(*args, **kwargs)
paddle_output = paddle_fn(*args, **kwargs)
for key in global_keys:
self.assert_nest_match(
sym_copied_fn.__globals__[key], paddle_fn.__globals__[key]
Expand Down