Skip to content

Commit c4c6db0

Browse files
committed
add paddle.incubate.sparse.is_same_shape
1 parent e6bf0bf commit c4c6db0

File tree

4 files changed

+35
-34
lines changed

4 files changed

+35
-34
lines changed

python/paddle/fluid/tests/unittests/test_sparse_multiary.py renamed to python/paddle/fluid/tests/unittests/test_sparse_is_same_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import unittest
1717

1818
import paddle
19-
from paddle.incubate.sparse.multiary import is_same_shape
19+
from paddle.incubate.sparse.binary import is_same_shape
2020

2121

22-
class TestSparseMultiaryAPI(unittest.TestCase):
22+
class TestSparseIsSameShapeAPI(unittest.TestCase):
2323
"""
2424
test paddle.incubate.sparse.is_same_shape
2525
"""

python/paddle/incubate/sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
from .binary import divide
4343
from .binary import multiply
4444
from .binary import subtract
45+
from .binary import is_same_shape
4546

4647
from .multiary import addmm
47-
from .multiary import is_same_shape
4848

4949
from . import nn
5050

python/paddle/incubate/sparse/binary.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,35 @@ def divide(x, y, name=None):
399399
if y.dtype != x.dtype:
400400
y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
401401
return _C_ops.final_state_sparse_divide(x, y)
402+
403+
404+
@dygraph_only
405+
def is_same_shape(x, y):
406+
"""
407+
Check whether x.shape equal to y.shape.
408+
409+
Args:
410+
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
411+
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
412+
413+
Returns:
414+
bool: True for same shape and False for different shape.
415+
416+
Examples:
417+
418+
.. code-block:: python
419+
420+
import paddle
421+
422+
x = paddle.rand([2, 3, 8])
423+
y = paddle.rand([2, 3, 8])
424+
y = y.to_sparse_csr()
425+
z = paddle.rand([2, 5])
426+
427+
paddle.incubate.sparse.is_same_shape(x, y)
428+
# True
429+
paddle.incubate.sparse.is_same_shape(x, z)
430+
# False
431+
432+
"""
433+
return x.is_same_shape(y)

python/paddle/incubate/sparse/multiary.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -79,34 +79,3 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
7979
8080
"""
8181
return _C_ops.final_state_sparse_addmm(input, x, y, alpha, beta)
82-
83-
84-
@dygraph_only
85-
def is_same_shape(x, y):
86-
"""
87-
Check whether x.shape equal to y.shape.
88-
89-
Args:
90-
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
91-
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
92-
93-
Returns:
94-
bool: True for same shape and False for different shape.
95-
96-
Examples:
97-
98-
.. code-block:: python
99-
100-
import paddle
101-
102-
x = paddle.rand([2, 3, 8])
103-
y = paddle.rand([2, 3, 8])
104-
z = paddle.rand([2, 5])
105-
106-
paddle.incubate.sparse.is_same_shape(x, y)
107-
# True
108-
paddle.incubate.sparse.is_same_shape(x, z)
109-
# False
110-
111-
"""
112-
return x.is_same_shape(y)

0 commit comments

Comments
 (0)