Skip to content

Commit 417ce3c

Browse files
authored
【PaddlePaddle Hackathon 3 No.23】为 Paddle 新增 paddle.incubate.sparse.is_same_shape 稀疏 API (#45086)
* add paddle.incubate.sparse.is_same_shape * add paddle.incubate.sparse.is_same_shape * add paddle.incubate.sparse.is_same_shape * add paddle.incubate.sparse.is_same_shape * fix doc
1 parent af3e27e commit 417ce3c

File tree

4 files changed

+173
-0
lines changed

4 files changed

+173
-0
lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,15 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
15711571
EAGER_CATCH_AND_THROW_RETURN_NULL
15721572
}
15731573

1574+
static PyObject* tensor_method_is_same_shape(TensorObject* self,
1575+
PyObject* args,
1576+
PyObject* kwargs) {
1577+
EAGER_TRY
1578+
auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
1579+
return ToPyObject(self->tensor.shape() == other.shape());
1580+
EAGER_CATCH_AND_THROW_RETURN_NULL
1581+
}
1582+
15741583
static PyObject* tensor__inplace_version(TensorObject* self,
15751584
PyObject* args,
15761585
PyObject* kwargs) {
@@ -1966,6 +1975,10 @@ PyMethodDef variable_methods[] = {
19661975
(PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
19671976
METH_VARARGS | METH_KEYWORDS,
19681977
NULL},
1978+
{"is_same_shape",
1979+
(PyCFunction)(void (*)(void))tensor_method_is_same_shape,
1980+
METH_VARARGS | METH_KEYWORDS,
1981+
NULL},
19691982
{"to_sparse_csr",
19701983
(PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
19711984
METH_VARARGS | METH_KEYWORDS,
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
import unittest
17+
18+
import paddle
19+
from paddle.incubate.sparse.binary import is_same_shape
20+
21+
22+
class TestSparseIsSameShapeAPI(unittest.TestCase):
23+
"""
24+
test paddle.incubate.sparse.is_same_shape
25+
"""
26+
27+
def setUp(self):
28+
self.shapes = [[2, 5, 8], [3, 4]]
29+
self.tensors = [
30+
paddle.rand(self.shapes[0]),
31+
paddle.rand(self.shapes[0]),
32+
paddle.rand(self.shapes[1])
33+
]
34+
self.sparse_dim = 2
35+
36+
def test_dense_dense(self):
37+
self.assertTrue(is_same_shape(self.tensors[0], self.tensors[1]))
38+
self.assertFalse(is_same_shape(self.tensors[0], self.tensors[2]))
39+
self.assertFalse(is_same_shape(self.tensors[1], self.tensors[2]))
40+
41+
def test_dense_csr(self):
42+
self.assertTrue(
43+
is_same_shape(self.tensors[0], self.tensors[1].to_sparse_csr()))
44+
self.assertFalse(
45+
is_same_shape(self.tensors[0], self.tensors[2].to_sparse_csr()))
46+
self.assertFalse(
47+
is_same_shape(self.tensors[1], self.tensors[2].to_sparse_csr()))
48+
49+
def test_dense_coo(self):
50+
self.assertTrue(
51+
is_same_shape(self.tensors[0],
52+
self.tensors[1].to_sparse_coo(self.sparse_dim)))
53+
self.assertFalse(
54+
is_same_shape(self.tensors[0],
55+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
56+
self.assertFalse(
57+
is_same_shape(self.tensors[1],
58+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
59+
60+
def test_csr_dense(self):
61+
self.assertTrue(
62+
is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[1]))
63+
self.assertFalse(
64+
is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[2]))
65+
self.assertFalse(
66+
is_same_shape(self.tensors[1].to_sparse_csr(), self.tensors[2]))
67+
68+
def test_csr_csr(self):
69+
self.assertTrue(
70+
is_same_shape(self.tensors[0].to_sparse_csr(),
71+
self.tensors[1].to_sparse_csr()))
72+
self.assertFalse(
73+
is_same_shape(self.tensors[0].to_sparse_csr(),
74+
self.tensors[2].to_sparse_csr()))
75+
self.assertFalse(
76+
is_same_shape(self.tensors[1].to_sparse_csr(),
77+
self.tensors[2].to_sparse_csr()))
78+
79+
def test_csr_coo(self):
80+
self.assertTrue(
81+
is_same_shape(self.tensors[0].to_sparse_csr(),
82+
self.tensors[1].to_sparse_coo(self.sparse_dim)))
83+
self.assertFalse(
84+
is_same_shape(self.tensors[0].to_sparse_csr(),
85+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
86+
self.assertFalse(
87+
is_same_shape(self.tensors[1].to_sparse_csr(),
88+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
89+
90+
def test_coo_dense(self):
91+
self.assertTrue(
92+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
93+
self.tensors[1]))
94+
self.assertFalse(
95+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
96+
self.tensors[2]))
97+
self.assertFalse(
98+
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
99+
self.tensors[2]))
100+
101+
def test_coo_csr(self):
102+
self.assertTrue(
103+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
104+
self.tensors[1].to_sparse_csr()))
105+
self.assertFalse(
106+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
107+
self.tensors[2].to_sparse_csr()))
108+
self.assertFalse(
109+
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
110+
self.tensors[2].to_sparse_csr()))
111+
112+
def test_coo_coo(self):
113+
self.assertTrue(
114+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
115+
self.tensors[1].to_sparse_coo(self.sparse_dim)))
116+
self.assertFalse(
117+
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
118+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
119+
self.assertFalse(
120+
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
121+
self.tensors[2].to_sparse_coo(self.sparse_dim)))
122+
123+
124+
if __name__ == "__main__":
125+
unittest.main()

python/paddle/incubate/sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .binary import divide
4343
from .binary import multiply
4444
from .binary import subtract
45+
from .binary import is_same_shape
4546

4647
from .multiary import addmm
4748

@@ -77,4 +78,5 @@
7778
'multiply',
7879
'divide',
7980
'coalesce',
81+
'is_same_shape',
8082
]

python/paddle/incubate/sparse/binary.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,36 @@ def divide(x, y, name=None):
399399
if y.dtype != x.dtype:
400400
y = _C_ops.sparse_cast(y, None, x.dtype)
401401
return _C_ops.sparse_divide(x, y)
402+
403+
404+
@dygraph_only
405+
def is_same_shape(x, y):
406+
"""
407+
Return the results of shape comparison between two Tensors, check whether x.shape equal to y.shape.
408+
Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are supported.
409+
410+
Args:
411+
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
412+
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
413+
414+
Returns:
415+
bool: True for same shape and False for different shape.
416+
417+
Examples:
418+
419+
.. code-block:: python
420+
421+
import paddle
422+
423+
x = paddle.rand([2, 3, 8])
424+
y = paddle.rand([2, 3, 8])
425+
y = y.to_sparse_csr()
426+
z = paddle.rand([2, 5])
427+
428+
paddle.incubate.sparse.is_same_shape(x, y)
429+
# True
430+
paddle.incubate.sparse.is_same_shape(x, z)
431+
# False
432+
433+
"""
434+
return x.is_same_shape(y)

0 commit comments

Comments
 (0)