Skip to content
Merged
134 changes: 0 additions & 134 deletions test/tensorrt/pass_test.py

This file was deleted.

6 changes: 3 additions & 3 deletions test/tensorrt/test_converter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def setUp(self):
self.python_api = bilinear_python_api
self.api_args = {
"x": np.random.random([2, 3, 6, 10]).astype("float32"),
"OutSize": np.array([12, 12], dtype="int32"),
"OutSize": np.array([12, 12], dtype="int64"),
"SizeTensor": None,
"Scale": None,
"attrs": {
Expand Down Expand Up @@ -225,8 +225,8 @@ def setUp(self):
"x": x_nchw,
"OutSize": None,
"SizeTensor": [
np.array([12], dtype="int32"),
np.array([12], dtype="int32"),
np.array([12], dtype="int64"),
np.array([12], dtype="int64"),
],
"Scale": None,
"attrs": {
Expand Down
8 changes: 4 additions & 4 deletions test/tensorrt/test_converter_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ class TestArangeTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.arange
self.api_args = {
"start": np.array([0]).astype("int32"),
"end": np.array([6]).astype("int32"),
"step": np.array([1]).astype("int32"),
"start": np.array([0]).astype("int64"),
"end": np.array([6]).astype("int64"),
"step": np.array([1]).astype("int64"),
}
self.program_config = {"feed_list": []}
self.min_shape = {}
Expand Down Expand Up @@ -195,7 +195,7 @@ class TestFullWithTensorTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.tensor.fill_constant
self.api_args = {
"shape": np.array([1]).astype("int32"),
"shape": np.array([1]).astype("int64"),
"dtype": "float32",
"value": np.array([0.0]).astype("float32"),
}
Expand Down
20 changes: 10 additions & 10 deletions test/tensorrt/test_converter_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class TestGreaterThanFloat32TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.greater_than
self.api_args = {
"x": np.random.randn(2, 3).astype(np.float32),
"y": np.random.randn(3).astype(np.float32),
"x": np.random.randn(2, 3).astype("float32"),
"y": np.random.randn(3).astype("float32"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1, 3], "y": [3]}
Expand All @@ -35,12 +35,12 @@ def test_trt_result(self):
self.check_trt_result()


class TestGreaterThanInt32TRTPattern(TensorRTBaseTest):
class TestGreaterThanInt64TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.greater_than
self.api_args = {
"x": np.random.randn(3).astype(np.int32),
"y": np.random.randn(3).astype(np.int32),
"x": np.random.randn(3).astype("int64"),
"y": np.random.randn(3).astype("int64"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1], "y": [1]}
Expand All @@ -54,8 +54,8 @@ class TestLessThanFloat32TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.less_than
self.api_args = {
"x": np.random.randn(2, 3).astype(np.float32),
"y": np.random.randn(3).astype(np.float32),
"x": np.random.randn(2, 3).astype("float32"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这类修改的作用是什么呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

规范了样式

"y": np.random.randn(3).astype("float32"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1, 3], "y": [3]}
Expand All @@ -65,12 +65,12 @@ def test_trt_result(self):
self.check_trt_result()


class TestLessThanInt32TRTPattern(TensorRTBaseTest):
class TestLessThanInt64TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.less_than
self.api_args = {
"x": np.random.randn(3).astype(np.int32),
"y": np.random.randn(3).astype(np.int32),
"x": np.random.randn(3).astype("int64"),
"y": np.random.randn(3).astype("int64"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1], "y": [1]}
Expand Down
Loading