Skip to content

Commit b655946

Browse files
committed
add unittest of elementwise mul, sub and div
1 parent c163aae commit b655946

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def generate_input(shape):
150150
return np.random.random(shape).astype(np.float32)
151151

152152
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
153-
for op_type in ["elementwise_add", "elementwise_mul"]:
153+
for op_type in [
154+
"elementwise_add", "elementwise_mul", "elementwise_sub",
155+
"elementwise_div"
156+
]:
154157
for axis in [0, -1]:
155158
self.dims = len(shape)
156159
dics = [{"axis": axis}]
@@ -306,7 +309,10 @@ def generate_input(shape):
306309
input1_shape = input1_shape_list[i]
307310
for j in range(6):
308311
input2_shape = input2_shape_list[j][i]
309-
for op_type in ["elementwise_add", "elementwise_mul"]:
312+
for op_type in [
313+
"elementwise_add", "elementwise_mul", "elementwise_sub",
314+
"elementwise_div"
315+
]:
310316
for axis in axis_list[j][i]:
311317
self.shape1 = input1_shape
312318
self.shape2 = input2_shape

python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,23 @@ def test_check_output(self):
5656
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
5757

5858

59+
class TensorRTSubgraphPassElementwiseBroadcastTest1(
60+
TensorRTSubgraphPassElementwiseBroadcastTest):
61+
def append_eltwise(self, data1, data2):
62+
return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0)
63+
64+
65+
class TensorRTSubgraphPassElementwiseBroadcastTest2(
66+
TensorRTSubgraphPassElementwiseBroadcastTest):
67+
def append_eltwise(self, data1, data2):
68+
return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0)
69+
70+
71+
class TensorRTSubgraphPassElementwiseBroadcastTest3(
72+
TensorRTSubgraphPassElementwiseBroadcastTest):
73+
def append_eltwise(self, data1, data2):
74+
return fluid.layers.elementwise_div(x=data1, y=data2, axis=0)
75+
76+
5977
if __name__ == "__main__":
6078
unittest.main()

0 commit comments

Comments
 (0)