-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[Dy2St]Add backend for to_static API #52596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
| from paddle.fluid import core, unique_name | ||
| from paddle.fluid.data_feeder import convert_dtype | ||
| from paddle.fluid.layer_helper import LayerHelper | ||
| from paddle.fluid.wrapped_decorator import signature_safe_contextmanager | ||
| from paddle.utils import gast | ||
|
|
||
| from .ast_utils import ast_to_source_code | ||
|
|
@@ -1491,7 +1492,10 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): | |
| return names | ||
|
|
||
|
|
||
| def prim_or_cinn_is_enabled(build_strategy): | ||
| def prim_or_cinn_is_enabled(build_strategy, backend): | ||
| if backend == 'CINN': | ||
| return True | ||
|
|
||
| if build_strategy is not None and build_strategy.build_cinn_pass: | ||
| return True | ||
|
|
||
|
|
@@ -1527,3 +1531,18 @@ def name_judge(): | |
| return True | ||
| else: | ||
| return False | ||
|
|
||
|
|
||
| @signature_safe_contextmanager | ||
| def dy2st_prim_guard(backend): | ||
|
||
| core.check_and_set_prim_all_enabled() | ||
| orign_fwd = core._is_fwd_prim_enabled() | ||
| orign_bwd = core._is_bwd_prim_enabled() | ||
|
|
||
| if backend == 'CINN': | ||
| core._set_prim_all_enabled(True) | ||
| try: | ||
| yield | ||
| finally: | ||
| core._set_prim_forward_enabled(orign_fwd) | ||
| core._set_prim_backward_enabled(orign_bwd) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -163,5 +163,20 @@ def test_cinn_prim(self): | |
| ) | ||
|
|
||
|
|
||
| class TestBackend(unittest.TestCase): | ||
| def test_backend(self): | ||
| x = paddle.randn([2, 4]) | ||
| out1 = self.forward(x, 'CINN') | ||
| out2 = self.forward(x, None) | ||
| np.testing.assert_allclose(out1, out2, rtol=1e-6) | ||
|
|
||
| def forward(self, x, beckend=None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我下面提个PR再修改,感谢~ |
||
| paddle.seed(2022) | ||
| net = PrimeNet() | ||
| net = paddle.jit.to_static(net, backend=beckend) | ||
| out = net(x) | ||
| return out | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
property参数从positional/keyword arg 皆可, 改成了 keyword-only arg,是不兼容升级了。检查过其影响吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
property参数目前只有语音套件会使用,已经和语音侧同步,语音侧使用方式都是通过k=v的方式传入该参数,所以对主框架以及套件不会产生其他影响。不兼容升级的邮件已经发送,评委已经通过