[PaddlePaddle Hackathon 5th No.42] Conversion rules, group 1 #301
Changes from all commits
@@ -473,6 +473,41 @@ def get_paddle_nodes(self, args, kwargs):
        return ast.parse(code).body


class TensorTileMatcher(BaseMatcher):
    def get_paddle_class_nodes(self, func, args, kwargs):
        self.parse_func(func)

        if len(args) == 1 and isinstance(args[0], (ast.List, ast.Tuple)):
            repeat_times_list = self.parse_args(args)[0]
        else:  # len(args) >= 1
            repeat_times_list = self.parse_args(args)

        kwargs = self.parse_kwargs(kwargs)
        if kwargs is None:
            return None

        if "reps" in kwargs:
| kwargs = {"repeat_times": kwargs.pop("reps"), **kwargs} | ||
| else: | ||
| kwargs = {"repeat_times": str(repeat_times_list).replace("'", ""), **kwargs} | ||
|
|
||
| code = "{}.tile({})".format(self.paddleClass, self.kwargs_to_str(kwargs)) | ||
| return ast.parse(code).body | ||
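
As a quick illustration, the rewrites this matcher is meant to produce look roughly like the pairs below. This is a sketch for readers of the diff, not part of the PR; the exact spelling of the generated paddle calls (list vs. tuple literals, argument order) is an assumption based on the string handling above.

# Illustrative only: expected torch -> paddle rewrites for Tensor.tile
# (the exact generated spelling is assumed, not taken from this PR).
examples = [
    ("a.tile(1, 2)",        "a.tile(repeat_times=[1, 2])"),   # positional ints collected into repeat_times
    ("a.tile((2, 3))",      "a.tile(repeat_times=(2, 3))"),   # a single list/tuple argument is passed through
    ("a.tile(reps=[2, 2])", "a.tile(repeat_times=[2, 2])"),   # the torch keyword `reps` is renamed
]
for torch_call, paddle_call in examples:
    print(f"{torch_call:22} -> {paddle_call}")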


class StftMatcher(BaseMatcher):

[Review comment] Added.

    def generate_code(self, kwargs):
        if "input" not in kwargs:
            kwargs["x"] = self.paddleClass
        if ("return_complex" in kwargs) and (kwargs.pop("return_complex") == "(False)"):
            code = "paddle.as_real(paddle.signal.stft({}))".format(
                self.kwargs_to_str(kwargs)
            )
        else:
            code = "paddle.signal.stft({})".format(self.kwargs_to_str(kwargs))
        return code
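
For context, the intent of the return_complex handling is roughly the mapping sketched below (not part of the diff): paddle.signal.stft has no return_complex argument and returns a complex tensor, so return_complex=False is emulated with paddle.as_real. The exact spelling of the generated calls is an assumption.

# Illustrative only: how Tensor.stft calls are expected to be rewritten.
examples = [
    ("x.stft(n_fft=4, return_complex=True)",
     "paddle.signal.stft(x=x, n_fft=4)"),
    ("x.stft(n_fft=4, return_complex=False)",
     "paddle.as_real(paddle.signal.stft(x=x, n_fft=4))"),
]
for torch_call, paddle_call in examples:
    print(torch_call, "->", paddle_call)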


class DeviceMatcher(BaseMatcher):
    def generate_code(self, kwargs):
        if len(kwargs) == 1:

@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.quantile")


def test_case_1():
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(0.6)
        """
    )
    obj.run(pytorch_code, ["result"])


def test_case_2():
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(0.6, dim=None)
        """
    )
    obj.run(pytorch_code, ["result"])


def test_case_3():
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(0.6, dim=0, keepdim=False)
        """
    )
    obj.run(pytorch_code, ["result"])


def test_case_4():
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(0.6, dim=1, keepdim=False)
        """
    )
    obj.run(pytorch_code, ["result"])


def test_case_5():
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(0.6, dim=1, keepdim=True)
        """
    )
    obj.run(pytorch_code, ["result"])
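
If broader coverage is wanted here as well, a keyword-argument form of the same call could be added to this file along these lines (a sketch only, not part of this PR; the case number is hypothetical):

def test_case_6():  # hypothetical extra case: q passed as a keyword argument
    pytorch_code = textwrap.dedent(
        """
        import torch
        result = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64).quantile(q=0.6, dim=1)
        """
    )
    obj.run(pytorch_code, ["result"])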

@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.stft")


def test_case_1():
    pytorch_code = textwrap.dedent(
        """
        import torch
        x = torch.tensor([[ (5.975718021392822+0j) ,
                            (5.975718021392822+0j) ,
                            (5.341437339782715+0j) ,
                            (5.404394626617432+0j) ,
                            (5.404394626617432+0j) ],
                          [ (0.0629572868347168+0j) ,
                            0.0629572868347168j ,
                            (-0.0629572868347168-0.6342806816101074j),
                            (0.6342806816101074+0j) ,
                            0.6342806816101074j ],
                          [(-0.4979677200317383+0j) ,
                            (0.4979677200317383+0j) ,
                            (0.13631296157836914+0j) ,
                            (-0.19927024841308594+0j) ,
                            (0.19927024841308594+0j) ]])
        result = x.stft(n_fft=4, onesided=False, return_complex=True)

[Review comment] Same as above.

        """
    )
    obj.run(pytorch_code, ["result"], rtol=1e-1, atol=1e-04)


def test_case_2():
    pytorch_code = textwrap.dedent(
        """
        import torch
        x = torch.tensor([[ (5.975718021392822+0j) ,
                            (5.975718021392822+0j) ,
                            (5.341437339782715+0j) ,
                            (5.404394626617432+0j) ,
                            (5.404394626617432+0j) ],
                          [ (0.0629572868347168+0j) ,
                            0.0629572868347168j ,
                            (-0.0629572868347168-0.6342806816101074j),
                            (0.6342806816101074+0j) ,
                            0.6342806816101074j ],
                          [(-0.4979677200317383+0j) ,
                            (0.4979677200317383+0j) ,
                            (0.13631296157836914+0j) ,
                            (-0.19927024841308594+0j) ,
                            (0.19927024841308594+0j) ]])
        result = x.stft(n_fft=4, onesided=False, return_complex=False)
        """
    )
    obj.run(pytorch_code, ["result"], rtol=1e-1, atol=1e-04)
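
If more coverage is wanted here, a case with an explicit hop_length on a 1-D complex signal could be added to this file (a hypothetical sketch, not part of this PR):

def test_case_3():  # hypothetical extra case: explicit hop_length on a 1-D complex input
    pytorch_code = textwrap.dedent(
        """
        import torch
        x = torch.tensor([(5.975718021392822+0j), (0.0629572868347168+0j),
                          (-0.4979677200317383+0j), (0.6342806816101074+0j),
                          (0.19927024841308594+0j)])
        result = x.stft(n_fft=4, hop_length=1, onesided=False, return_complex=True)
        """
    )
    obj.run(pytorch_code, ["result"], rtol=1e-1, atol=1e-04)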

@@ -30,6 +30,18 @@ def test_case_1():
    obj.run(
        pytorch_code,
        ["result"],
        unsupport=True,
        reason="Paddle not support this api convert now",
    )

[Review comment] Refer to the latest test_Tensor_permute tests, which include a dozen or so cases.


def test_case_2():

[Review comment] Same as above: the test coverage is incomplete; refer to the test_Tensor_reshape tests.

    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        result = a.tile(1, 2)
        """
    )
    obj.run(
        pytorch_code,
        ["result"],
    )
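
Following the review comments above, broader coverage in the spirit of test_Tensor_permute could look like the cases below (a sketch only; the case numbers and inputs are hypothetical and not part of this PR):

def test_case_3():  # hypothetical: a single tuple argument
    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        result = a.tile((2, 3))
        """
    )
    obj.run(pytorch_code, ["result"])


def test_case_4():  # hypothetical: fewer repeat factors than tensor dims
    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        result = a.tile(2)
        """
    )
    obj.run(pytorch_code, ["result"])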

@@ -0,0 +1,49 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.to_sparse")


def test_case_1():
    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        b = a.to_sparse(1)
        result = b.to_dense()
        """
    )
    obj.run(
        pytorch_code,
        ["result"],
    )


def test_case_2():
    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        b = a.to_sparse(sparse_dim = 1)
        result = b.to_dense()
        """
    )
    obj.run(
        pytorch_code,
        ["result"],
    )
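
If the coverage remarks on the other test files apply here as well, one more case with sparse_dim equal to the tensor's rank could be added to this file (a hypothetical sketch, not part of this PR):

def test_case_3():  # hypothetical: sparse_dim equal to the tensor's rank
    pytorch_code = textwrap.dedent(
        """
        import torch
        a = torch.Tensor([[1.,2.], [3.,4.]])
        b = a.to_sparse(2)
        result = b.to_dense()
        """
    )
    obj.run(
        pytorch_code,
        ["result"],
    )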





Uh oh!
There was an error while loading. Please reload this page.