Skip to content

Commit 93d8812

Browse files
authored
[SOT] Mark some APIs can be directly run in simulation mode (#70293)
1 parent d172876 commit 93d8812

File tree

4 files changed

+84
-0
lines changed

4 files changed

+84
-0
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_break_graph_api,
3737
is_break_graph_tensor_methods,
3838
is_builtin_fn,
39+
is_directly_run_api,
3940
is_not_supported_paddle_layer,
4041
is_paddle_api,
4142
magic_method_builtin_dispatch,
@@ -699,6 +700,22 @@ def call_function(self, /, *args, **kwargs):
699700
)
700701
return handler(*args, **kwargs)
701702

703+
# If API can be directly called in simulation mode (e.g. user defined native code
704+
# without graph affect), we can directly call it.
705+
if is_directly_run_api(self.value):
706+
from ..function_graph import convert_to_py_value
707+
708+
res = self.value(
709+
*convert_to_py_value(args),
710+
**convert_to_py_value(kwargs),
711+
)
712+
713+
return VariableFactory.from_value(
714+
res,
715+
self.graph,
716+
DummyTracker([self, *list(args), *list(kwargs.values())]),
717+
)
718+
702719
# Try to inline call the magic function
703720
magic_methods = magic_method_builtin_dispatch(self.value)
704721
for magic_method in magic_methods:

python/paddle/jit/sot/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from .paddle_api_config import ( # noqa: F401
5050
get_tensor_methods,
5151
is_break_graph_tensor_methods,
52+
is_directly_run_api,
5253
is_inplace_api,
5354
is_not_supported_paddle_layer,
5455
)

python/paddle/jit/sot/utils/paddle_api_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,28 @@ def is_break_graph_tensor_methods(method_name):
129129

130130
def add_break_graph_apis(apis: list):
131131
break_graph_set.update(apis)
132+
133+
134+
def is_directly_run_api(api):
135+
from .utils import hashable
136+
137+
if not hashable(api):
138+
return False
139+
NATIVE_CODE_PURE_FUNCTIONS = {
140+
paddle.base.libpaddle.is_compiled_with_avx,
141+
paddle.base.libpaddle.is_compiled_with_cuda,
142+
paddle.base.libpaddle.is_compiled_with_cudnn_frontend,
143+
paddle.base.libpaddle.is_compiled_with_rocm,
144+
paddle.base.libpaddle.is_compiled_with_custom_device,
145+
paddle.base.libpaddle.is_compiled_with_ipu,
146+
paddle.base.libpaddle.is_compiled_with_xpu,
147+
paddle.base.libpaddle.is_compiled_with_mkldnn,
148+
paddle.base.libpaddle.is_compiled_with_nccl,
149+
paddle.base.libpaddle.is_compiled_with_mpi,
150+
paddle.base.libpaddle.is_compiled_with_mpi_aware,
151+
paddle.base.libpaddle.is_compiled_with_cinn,
152+
paddle.base.libpaddle.is_compiled_with_distribute,
153+
paddle.base.libpaddle.is_compiled_with_brpc,
154+
paddle.base.libpaddle.is_compiled_with_dist,
155+
}
156+
return api in NATIVE_CODE_PURE_FUNCTIONS

test/sot/test_builtin_dispatch.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,5 +388,46 @@ def test_builtin_type_conversion_breakgraph(self):
388388
)
389389

390390

391+
@check_no_breakgraph
392+
def test_native_code_function():
393+
res1 = paddle.base.libpaddle.is_compiled_with_avx()
394+
res2 = paddle.base.libpaddle.is_compiled_with_cuda()
395+
res3 = paddle.base.libpaddle.is_compiled_with_cudnn_frontend()
396+
res4 = paddle.base.libpaddle.is_compiled_with_rocm()
397+
res5 = paddle.base.libpaddle.is_compiled_with_custom_device("npu")
398+
res6 = paddle.base.libpaddle.is_compiled_with_ipu()
399+
res7 = paddle.base.libpaddle.is_compiled_with_xpu()
400+
res8 = paddle.base.libpaddle.is_compiled_with_mkldnn()
401+
res9 = paddle.base.libpaddle.is_compiled_with_nccl()
402+
res10 = paddle.base.libpaddle.is_compiled_with_mpi()
403+
res11 = paddle.base.libpaddle.is_compiled_with_mpi_aware()
404+
res12 = paddle.base.libpaddle.is_compiled_with_cinn()
405+
res13 = paddle.base.libpaddle.is_compiled_with_distribute()
406+
res14 = paddle.base.libpaddle.is_compiled_with_brpc()
407+
res15 = paddle.base.libpaddle.is_compiled_with_dist()
408+
return (
409+
res1,
410+
res2,
411+
res3,
412+
res4,
413+
res5,
414+
res6,
415+
res7,
416+
res8,
417+
res9,
418+
res10,
419+
res11,
420+
res12,
421+
res13,
422+
res14,
423+
res15,
424+
)
425+
426+
427+
class TestNativeCodeFunction(TestCaseBase):
428+
def test_native_code_function(self):
429+
self.assert_results(test_native_code_function)
430+
431+
391432
if __name__ == "__main__":
392433
unittest.main()

0 commit comments

Comments
 (0)