
Commit 973dadd

【SCU】【Paddle Tensor No.11】Add Tensor.__dlpack__() (#69781)
* resubmit
* modify
* add test file
* finish
* add test
* modify
* precommit
* modify
* modify
* modify
1 parent 39cab51 commit 973dadd

File tree

3 files changed: +334 -1 lines changed

python/paddle/base/dygraph/tensor_patch_methods.py

Lines changed: 34 additions & 1 deletion
@@ -129,7 +129,6 @@ def _to_static_var(self, to_parameter=False, **kwargs):
            'strides',
            'offset',
            '__cuda_array_interface__',
-            '__dlpack_device__',
        ]
        param_keys = ['stop_gradient', 'trainable']
        if isinstance(self, EagerParamBase):
@@ -1366,6 +1365,39 @@ def __cuda_array_interface__(self):
            "version": 2,
        }

+    def __dlpack__(self, stream=None):
+        """
+        Creates a DLPack capsule of the current tensor to be exported to other libraries.
+
+        Args:
+            stream (int | None): An optional Python integer representing a pointer
+                to a CUDA stream. The tensor is synchronized with this stream
+                before exporting.
+                If None or -1, no synchronization is performed.
+                If 0, the default stream is used.
+        """
+
+        if self.is_sparse():
+            raise AttributeError(
+                "Can't get __dlpack__ from a sparse Tensor. "
+                "Convert it to a dense Tensor before exporting it via DLPack."
+            )
+
+        if not self.stop_gradient:
+            raise RuntimeError(
+                "Can't get __dlpack__ from a Tensor that requires gradients. "
+                "If gradients aren't required, use tensor.detach() to get a tensor without gradient."
+            )
+
+        if stream is not None:
+            if self.place.is_gpu_place():
+                current_stream = paddle.device.cuda.current_stream()
+                if stream != current_stream:
+                    event = paddle.device.cuda.Event()
+                    event.record(current_stream)
+                    current_stream.synchronize()
+
+        return paddle.to_dlpack(self)
+
    if not hasattr(core, "eager"):
        return

@@ -1410,6 +1442,7 @@ def __cuda_array_interface__(self):
        ("_use_gpudnn", _use_gpudnn),
        ("_md5sum", _md5sum),
        ("__cuda_array_interface__", __cuda_array_interface__),
+        ("__dlpack__", __dlpack__),
        ("__dlpack_device__", __dlpack_device__),
    ):
        setattr(core.eager.Tensor, method_name, method)
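
With this change, every eager Tensor exposes both halves of the DLPack protocol (the new __dlpack__ plus the pre-existing __dlpack_device__), so other libraries can import Paddle tensors without an explicit paddle.utils.dlpack.to_dlpack call. The sketch below illustrates the intended round trip; it is a minimal example, assuming a Paddle build that contains this commit and, for the NumPy line, NumPy 1.23+ (which provides np.from_dlpack). The variable names are illustrative only.

import numpy as np
import paddle

# Export a tensor as a DLPack capsule and re-import it; the round trip is zero-copy.
x = paddle.to_tensor([1.0, 2.0, 3.0])
capsule = x.__dlpack__()           # equivalent to paddle.to_dlpack(x)
y = paddle.from_dlpack(capsule)    # shares storage with x
assert x.data_ptr() == y.data_ptr()

# Consumers such as NumPy drive __dlpack__/__dlpack_device__ themselves.
arr = np.from_dlpack(paddle.to_tensor([4.0, 5.0, 6.0], place=paddle.CPUPlace()))

# Tensors that require gradients must be detached first; otherwise
# __dlpack__ raises RuntimeError (see the check in the diff above).
g = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
capsule_g = g.detach().__dlpack__()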

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@
        'value',
        'zero_',
        "__cuda_array_interface__",
+        '__dlpack__',
        "__dlpack_device__",
    ]
)
Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from utils import dygraph_guard

import paddle
from paddle import base


@unittest.skipIf(
    paddle.core.is_compiled_with_xpu(),
    "xpu does not support dlpack",
)
class TestDLPack(unittest.TestCase):
    def test_dlpack_dygraph(self):
        with dygraph_guard():
            tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype("int"))
            dlpack_v1 = paddle.utils.dlpack.to_dlpack(tensor)
            out_from_dlpack_v1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
            dlpack_v2 = tensor.__dlpack__()
            out_from_dlpack_v2 = paddle.from_dlpack(dlpack_v2)
            self.assertTrue(
                isinstance(out_from_dlpack_v1, paddle.base.core.eager.Tensor)
            )
            self.assertTrue(
                isinstance(out_from_dlpack_v2, paddle.base.core.eager.Tensor)
            )
            self.assertEqual(str(tensor.place), str(out_from_dlpack_v1.place))
            self.assertEqual(str(tensor.place), str(out_from_dlpack_v2.place))
            np.testing.assert_array_equal(
                out_from_dlpack_v1.numpy(), np.array([1, 2, 3, 4]).astype("int")
            )
            np.testing.assert_array_equal(
                out_from_dlpack_v2.numpy(), np.array([1, 2, 3, 4]).astype("int")
            )

    def test_dlpack_tensor_larger_than_2dim(self):
        with dygraph_guard():
            numpy_data = np.random.randn(4, 5, 6)
            t = paddle.to_tensor(numpy_data)
            dlpack_v1 = paddle.utils.dlpack.to_dlpack(t)
            dlpack_v2 = t.__dlpack__()
            out_v1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
            out_v2 = paddle.from_dlpack(dlpack_v2)
            self.assertEqual(str(t.place), str(out_v1.place))
            self.assertEqual(str(t.place), str(out_v2.place))
            np.testing.assert_allclose(numpy_data, out_v1.numpy(), rtol=1e-05)
            np.testing.assert_allclose(numpy_data, out_v2.numpy(), rtol=1e-05)

    def test_dlpack_dtype_and_place_consistency(self):
        with dygraph_guard():
            dtypes = [
                "float16",
                "float32",
                "float64",
                "int8",
                "int16",
                "int32",
                "int64",
                "uint8",
                "bool",
            ]
            places = [paddle.CPUPlace()]
            if paddle.device.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
                dtypes.append("bfloat16")

            data = np.ones((2, 3, 4))
            for place in places:
                for dtype in dtypes:
                    x = paddle.to_tensor(data, dtype=dtype, place=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    o_v1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    dlpack_v2 = x.__dlpack__()
                    o_v2 = paddle.from_dlpack(dlpack_v2)
                    self.assertEqual(x.dtype, o_v1.dtype)
                    self.assertEqual(x.dtype, o_v2.dtype)
                    np.testing.assert_allclose(
                        x.numpy(), o_v1.numpy(), rtol=1e-05
                    )
                    np.testing.assert_allclose(
                        x.numpy(), o_v2.numpy(), rtol=1e-05
                    )
                    self.assertEqual(str(x.place), str(o_v1.place))
                    self.assertEqual(str(x.place), str(o_v2.place))

            complex_dtypes = ["complex64", "complex128"]
            for place in places:
                for dtype in complex_dtypes:
                    x = paddle.to_tensor(
                        [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]],
                        dtype=dtype,
                        place=place,
                    )
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    o_v1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    dlpack_v2 = x.__dlpack__()
                    o_v2 = paddle.from_dlpack(dlpack_v2)
                    self.assertEqual(x.dtype, o_v1.dtype)
                    self.assertEqual(x.dtype, o_v2.dtype)
                    np.testing.assert_allclose(
                        x.numpy(), o_v1.numpy(), rtol=1e-05
                    )
                    np.testing.assert_allclose(
                        x.numpy(), o_v2.numpy(), rtol=1e-05
                    )
                    self.assertEqual(str(x.place), str(o_v1.place))
                    self.assertEqual(str(x.place), str(o_v2.place))

    def test_dlpack_deletion(self):
        # See Paddle issue 47171
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    a = paddle.rand(shape=[3, 5], dtype="float32").to(
                        device=place
                    )
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(a)
                    dlpack_v2 = a.__dlpack__()
                    b1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    b2 = paddle.from_dlpack(dlpack_v2)
                    self.assertEqual(str(a.place), str(b1.place))
                    self.assertEqual(str(a.place), str(b2.place))

    def test_to_dlpack_for_loop(self):
        # See Paddle issue 50120
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.rand([3, 5]).to(device=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    dlpack_v2 = x.__dlpack__()

    def test_to_dlpack_modification(self):
        # See Paddle issue 50120
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.rand([3, 5]).to(device=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    dlpack_v2 = x.__dlpack__()
                    y1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    y2 = paddle.from_dlpack(dlpack_v2)
                    y1[1:2, 2:5] = 2.0
                    y2[1:2, 2:5] = 2.0
                    np.testing.assert_allclose(x.numpy(), y1.numpy())
                    np.testing.assert_allclose(x.numpy(), y2.numpy())
                    self.assertEqual(str(x.place), str(y1.place))
                    self.assertEqual(str(x.place), str(y2.place))

    def test_to_dlpack_data_ptr_consistency(self):
        # See Paddle issue 50120
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.rand([3, 5]).to(device=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    dlpack_v2 = x.__dlpack__()
                    y1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    y2 = paddle.from_dlpack(dlpack_v2)

                    self.assertEqual(x.data_ptr(), y1.data_ptr())
                    self.assertEqual(x.data_ptr(), y2.data_ptr())
                    self.assertEqual(str(x.place), str(y1.place))
                    self.assertEqual(str(x.place), str(y2.place))

    def test_to_dlpack_strides_consistency(self):
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.rand([10, 10]).to(device=place)
                    x_strided = x[::2, ::2]
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x_strided)
                    dlpack_v2 = x_strided.__dlpack__()
                    y1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    y2 = paddle.from_dlpack(dlpack_v2)

                    self.assertEqual(x_strided.strides, y1.strides)
                    self.assertEqual(x_strided.strides, y2.strides)
                    self.assertEqual(str(x_strided.place), str(y1.place))
                    self.assertEqual(str(x_strided.place), str(y2.place))
                    np.testing.assert_equal(x_strided.numpy(), y1.numpy())
                    np.testing.assert_equal(x_strided.numpy(), y2.numpy())

    def test_to_dlpack_from_zero_dim(self):
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.to_tensor(1.0, place=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    dlpack_v2 = x.__dlpack__()
                    y1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    y2 = paddle.from_dlpack(dlpack_v2)
                    self.assertEqual(x.data_ptr(), y1.data_ptr())
                    self.assertEqual(x.data_ptr(), y2.data_ptr())
                    self.assertEqual(str(x.place), str(y1.place))
                    self.assertEqual(str(x.place), str(y2.place))
                    self.assertEqual(y1.shape, [])
                    self.assertEqual(y2.shape, [])
                    self.assertEqual(y1.numel().item(), 1)
                    self.assertEqual(y2.numel().item(), 1)
                    np.testing.assert_array_equal(x.numpy(), y1.numpy())
                    np.testing.assert_array_equal(x.numpy(), y2.numpy())

    def test_to_dlpack_from_zero_size(self):
        with dygraph_guard():
            places = [base.CPUPlace()]
            if paddle.is_compiled_with_cuda():
                places.append(base.CUDAPlace(0))
            for place in places:
                for _ in range(4):
                    x = paddle.zeros([0, 10]).to(device=place)
                    dlpack_v1 = paddle.utils.dlpack.to_dlpack(x)
                    dlpack_v2 = x.__dlpack__()
                    y1 = paddle.utils.dlpack.from_dlpack(dlpack_v1)
                    y2 = paddle.from_dlpack(dlpack_v2)
                    self.assertEqual(x.data_ptr(), y1.data_ptr())
                    self.assertEqual(x.data_ptr(), y2.data_ptr())
                    self.assertEqual(str(x.place), str(y1.place))
                    self.assertEqual(str(x.place), str(y2.place))
                    self.assertEqual(y1.shape, [0, 10])
                    self.assertEqual(y2.shape, [0, 10])
                    self.assertEqual(y1.numel().item(), 0)
                    self.assertEqual(y2.numel().item(), 0)
                    np.testing.assert_array_equal(x.numpy(), y1.numpy())
                    np.testing.assert_array_equal(x.numpy(), y2.numpy())

    def test_dlpack_with_custom_stream(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Test requires CUDA support.")
        with dygraph_guard():
            paddle.set_device('gpu:0')
            s1 = paddle.device.Stream()
            s2 = paddle.device.Stream()
            e = paddle.device.Event()
            s2.wait_event(e)
            x = paddle.to_tensor([1, 2, 3], dtype='float32')
            s1.synchronize()
            dlpack_capsule = x.__dlpack__(s1)
            y = paddle.from_dlpack(dlpack_capsule)
            np.testing.assert_array_equal(x.numpy(), y.numpy())
            self.assertTrue(s1.query(), "Stream s1 did not complete all tasks.")
            self.assertTrue(s2.query(), "Stream s2 did not complete all tasks.")


@unittest.skipIf(
    paddle.core.is_compiled_with_xpu(),
    "xpu does not support dlpack",
)
class TestRaiseError(unittest.TestCase):
    def test_dlpack_invalid_sparse(self):
        sparse_tensor = paddle.sparse.sparse_coo_tensor(
            indices=[[0]], values=[1], shape=[3]
        )
        with self.assertRaises(AttributeError):
            sparse_tensor.__dlpack__()

    def test_dlpack_requires_grad(self):
        tensor_with_grad = paddle.to_tensor(
            [1.0, 2.0, 3.0], stop_gradient=False
        )
        with self.assertRaises(RuntimeError):
            tensor_with_grad.__dlpack__()


if __name__ == "__main__":
    unittest.main()
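
The custom-stream test above passes a Stream object straight into __dlpack__; a generic consumer library would normally drive the handshake itself. The sketch below is a rough illustration of that consumer-side flow under the DLPack spec, not Paddle code: it assumes __dlpack_device__() returns the spec's (device_type, device_id) pair, where kDLCUDA is device type 2, and import_from_dlpack is a hypothetical helper name introduced here for illustration.

import paddle

KDL_CUDA = 2  # DLPack device-type code for CUDA GPUs (per the DLPack spec)

def import_from_dlpack(producer):
    # Hypothetical consumer-side handshake: inspect the producer's device,
    # then request a capsule, asking for default-stream synchronization on GPU.
    device_type, _device_id = producer.__dlpack_device__()
    if device_type == KDL_CUDA:
        capsule = producer.__dlpack__(stream=0)  # 0 -> default stream, per the docstring
    else:
        capsule = producer.__dlpack__()
    return paddle.from_dlpack(capsule)

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
y = import_from_dlpack(x)
assert y.shape == [2, 2]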
