Skip to content

Commit f05046a

Browse files
committed
Update test_converter_others.py
1 parent 9087b4c commit f05046a

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

test/tensorrt/test_converter_others.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,9 @@ def wrapper_temporal_shift(x):
526526

527527

528528
def wrapper_temporal_shift_2(x, seg_num, shift_ratio):
529-
return paddle.nn.functional.temporal_shift(x=x, seg_num=seg_num, shift_ratio=shift_ratio)
529+
return paddle.nn.functional.temporal_shift(
530+
x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio
531+
)
530532

531533

532534
class TestTemporalShiftTRTPatternError1(TensorRTBaseTest):
@@ -548,14 +550,14 @@ class TestTemporalShiftTRTPatternError2(TensorRTBaseTest):
548550
def setUp(self):
549551
self.python_api = wrapper_temporal_shift_2
550552
self.api_args = {
551-
"x": np.random.random([4, 9, 7, 7, 7]).astype(np.float32),
553+
"x": np.random.random([4, 9, 7]).astype(np.float32),
552554
"seg_num": 2,
553555
"shift_ratio": 0.2,
554556
}
555557
self.program_config = {"feed_list": ["x"]}
556-
self.min_shape = {"x": [2, 9, 7, 7, 7]}
557-
self.opt_shape = {"x": [2, 9, 7, 7, 7]}
558-
self.max_shape = {"x": [10, 9, 7, 7, 7]}
558+
self.min_shape = {"x": [2, 9, 7]}
559+
self.opt_shape = {"x": [2, 9, 7]}
560+
self.max_shape = {"x": [10, 9, 7]}
559561

560562
def test_trt_result(self):
561563
self.check_marker(expected_result=False)

0 commit comments

Comments
 (0)