Skip to content

Commit 1bb12d0

Browse files
authored
[SOT] Support .tolist() in SOT Static Mode (#71201)
1 parent 973e1ba commit 1bb12d0

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed

python/paddle/jit/sot/utils/paddle_api_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def _get_tensor_methods():
123123
'register_hook',
124124
'numpy',
125125
'clear_gradient',
126+
'tolist',
126127
# TODO: Browse all possible functions and make prior judgments.
127128
}
128129

python/paddle/pir/math_op_patch.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ def numpy(self):
10001000
ndarray: dtype is same as current Variable
10011001
Examples:
10021002
.. code-block:: python
1003+
10031004
>>> import paddle
10041005
>>> import paddle.base as base
10051006
>>> from paddle.nn import Linear
@@ -1013,6 +1014,32 @@ def numpy(self):
10131014
"""
10141015
pass
10151016

1017+
@fake_interface_only
1018+
def tolist(self):
1019+
"""
1020+
**Notes**:
1021+
**This API is ONLY available in Dygraph mode**
1022+
Returns a Python list that contains the elements of current :ref:`api_guide_Variable_en`
1023+
1024+
Returns:
1025+
list: The Python list containing the elements of current Variable.
1026+
1027+
Returns type:
1028+
list: Elements have the same dtype as current Variable
1029+
1030+
Examples:
1031+
.. code-block:: python
1032+
1033+
>>> import paddle
1034+
>>> import paddle.base as base
1035+
>>> import numpy as np
1036+
>>> data = np.random.uniform(-1, 1, [2, 3]).astype('float32')
1037+
>>> with base.dygraph.guard():
1038+
... x = paddle.to_tensor(data)
1039+
... print(x.tolist()) # Convert tensor to Python list
1040+
"""
1041+
pass
1042+
10161043
@fake_interface_only
10171044
def register_hook(self, hook):
10181045
"""
@@ -1049,6 +1076,7 @@ def register_hook(self, hook):
10491076
('values', values),
10501077
("_to", _to),
10511078
("to", to),
1079+
("tolist", tolist),
10521080
("numpy", numpy),
10531081
("register_hook", register_hook),
10541082
# For basic operators

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
'strides',
7676
'to_sparse_coo',
7777
'to_sparse_csr',
78-
'tolist',
7978
'value',
8079
'zero_',
8180
"__cuda_array_interface__",

test/sot/test_18_tensor_method.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
from test_case_base import TestCaseBase, test_with_faster_guard
1818

1919
import paddle
20+
from paddle.jit.sot.psdb import check_no_breakgraph
2021

2122

23+
@check_no_breakgraph
2224
def tensor_method_call_1(x: paddle.Tensor):
2325
y = x + 1
2426
return y.mean()
2527

2628

29+
@check_no_breakgraph
2730
def tensor_method_call_2(a: paddle.Tensor, b: paddle.Tensor):
2831
c = a.add(b)
2932
d = c.multiply(a)
@@ -39,15 +42,15 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor):
3942
return func(a)
4043

4144

42-
def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
45+
@check_no_breakgraph
46+
def tensor_method_property_without_breakgraph(
47+
a: paddle.Tensor, b: paddle.Tensor
48+
):
4349
return (
4450
a.name,
45-
str(a.place),
4651
a.persistable,
4752
a.dtype,
48-
a.type,
4953
a.is_tensor(),
50-
a.clear_gradient(),
5154
a @ b.T.astype(a.dtype)
5255
+ len(a.shape)
5356
+ b.size
@@ -57,10 +60,22 @@ def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
5760
)
5861

5962

63+
def tensor_method_property_with_breakgraph(a: paddle.Tensor, b: paddle.Tensor):
64+
return (
65+
a.type,
66+
a.numpy(),
67+
a.tolist(),
68+
str(a.place),
69+
a.clear_gradient(),
70+
)
71+
72+
73+
@check_no_breakgraph
6074
def tensor_method_property_mT(a: paddle.Tensor):
6175
return a.mT
6276

6377

78+
@check_no_breakgraph
6479
def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
6580
c = a + b
6681
return c.name
@@ -87,7 +102,8 @@ def test_tensor_method_passed_by_user(self):
87102
def test_tensor_method_property(self):
88103
x = paddle.rand([42, 24], dtype='float64')
89104
y = paddle.rand([42, 24], dtype='float32')
90-
self.assert_results(tensor_method_property, x, y)
105+
self.assert_results(tensor_method_property_without_breakgraph, x, y)
106+
self.assert_results(tensor_method_property_with_breakgraph, x, y)
91107

92108
@unittest.skip("TODO: dynamic tensor name is different")
93109
def test_middle_tensor_name(self):

0 commit comments

Comments
 (0)