Skip to content

Commit ed19d37

Browse files
authored
Add Unsqueeze op composite rule (#51527)
* first test * add unsqueeze_op
1 parent b76ab79 commit ed19d37

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS
12151215
test_elementwise_pow_op
12161216
test_transpose_op
12171217
test_reshape_op
1218-
test_mean_op)
1218+
test_mean_op
1219+
test_unsqueeze2_op)
12191220

12201221
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
12211222
if(WITH_CINN)

python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@ def setUp(self):
3636
"Out": self.inputs["X"].reshape(self.new_shape),
3737
"XShape": np.random.random(self.ori_shape).astype("float64"),
3838
}
39+
self.prim_op_type = "comp"
3940

4041
def test_check_output(self):
41-
self.check_output(no_check_set=["XShape"], check_eager=True)
42+
self.check_output(
43+
no_check_set=["XShape"], check_eager=True, check_prim=True
44+
)
4245

4346
def test_check_grad(self):
4447
self.check_grad(["X"], "Out", check_eager=True)
@@ -89,20 +92,23 @@ def init_test_case(self):
8992
self.ori_shape = ()
9093
self.axes = (-1,)
9194
self.new_shape = 1
95+
self.enable_cinn = False
9296

9397

9498
class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
9599
def init_test_case(self):
96100
self.ori_shape = ()
97101
self.axes = (-1, 1)
98102
self.new_shape = (1, 1)
103+
self.enable_cinn = False
99104

100105

101106
class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
102107
def init_test_case(self):
103108
self.ori_shape = ()
104109
self.axes = (0, 1, 2)
105110
self.new_shape = (1, 1, 1)
111+
self.enable_cinn = False
106112

107113

108114
# axes is a list(with tensor)

python/paddle/incubate/autograd/composite_rules.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,23 @@ def relu_composite(x):
371371
"""define composite rule of op relu."""
372372
# relu(x) = max(x, 0)
373373
return maximum(x, zeros_like(x))
374+
375+
376+
@REGISTER_COMPOSITE('unsqueeze2')
377+
def unsqueeze_composite(x, axis):
378+
"""define composite rule of op unsqueeze"""
379+
"""using reshape to implement unsqueeze op"""
380+
x_shape = list(x.shape)
381+
axis_list = list(axis)
382+
for i in axis_list:
383+
if i < 0:
384+
i += len(x_shape) + 1
385+
x_shape = (
386+
x_shape[:i]
387+
+ [
388+
1,
389+
]
390+
+ x_shape[i:]
391+
)
392+
out = reshape(x, x_shape)
393+
return [out, None]

0 commit comments

Comments
 (0)