@@ -526,7 +526,9 @@ def wrapper_temporal_shift(x):
526526
527527
528528def wrapper_temporal_shift_2 (x , seg_num , shift_ratio ):
529- return paddle .nn .functional .temporal_shift (x = x , seg_num = seg_num , shift_ratio = shift_ratio )
529+ return paddle .nn .functional .temporal_shift (
530+ x = paddle .randn ([4 , 9 , 7 , 7 ]), seg_num = seg_num , shift_ratio = shift_ratio
531+ )
530532
531533
532534class TestTemporalShiftTRTPatternError1 (TensorRTBaseTest ):
@@ -548,14 +550,14 @@ class TestTemporalShiftTRTPatternError2(TensorRTBaseTest):
548550 def setUp (self ):
549551 self .python_api = wrapper_temporal_shift_2
550552 self .api_args = {
551- "x" : np .random .random ([4 , 9 , 7 , 7 , 7 ]).astype (np .float32 ),
553+ "x" : np .random .random ([4 , 9 , 7 ]).astype (np .float32 ),
552554 "seg_num" : 2 ,
553555 "shift_ratio" : 0.2 ,
554556 }
555557 self .program_config = {"feed_list" : ["x" ]}
556- self .min_shape = {"x" : [2 , 9 , 7 , 7 , 7 ]}
557- self .opt_shape = {"x" : [2 , 9 , 7 , 7 , 7 ]}
558- self .max_shape = {"x" : [10 , 9 , 7 , 7 , 7 ]}
558+ self .min_shape = {"x" : [2 , 9 , 7 ]}
559+ self .opt_shape = {"x" : [2 , 9 , 7 ]}
560+ self .max_shape = {"x" : [10 , 9 , 7 ]}
559561
560562 def test_trt_result (self ):
561563 self .check_marker (expected_result = False )
0 commit comments