@@ -231,7 +231,7 @@ def squeeze_converter(network, paddle_op, inputs):
231231 return layer .get_output (0 )
232232
233233
234- @converter_registry .register ("pd_op.expand" , trt_version = "8.x " )
234+ @converter_registry .register ("pd_op.expand" , trt_version = "trt_version_ge=8.0 " )
235235def expand_converter (network , paddle_op , inputs ):
236236 input = inputs [0 ]
237237 input_dims = input .shape
@@ -253,7 +253,9 @@ def expand_converter(network, paddle_op, inputs):
253253 return trt_expand (network , input , rank , shape_tensor , shape_rank )
254254
255255
256- @converter_registry .register ("pd_op.expand_as" , trt_version = "8.x" )
256+ @converter_registry .register (
257+ "pd_op.expand_as" , trt_version = "trt_version_ge=8.0"
258+ )
257259def expand_as_converter (network , paddle_op , inputs ):
258260 input = inputs [0 ]
259261 input_dims = input .shape
@@ -299,7 +301,7 @@ def cast_converter(network, paddle_op, inputs):
299301 return cast_layer .get_output (0 )
300302
301303
302- @converter_registry .register ("pd_op.slice" , trt_version = "8.x " )
304+ @converter_registry .register ("pd_op.slice" , trt_version = "trt_version_ge=8.0 " )
303305def slice_converter (network , paddle_op , inputs ):
304306 input_tensor = inputs [0 ]
305307 input_shape = paddle_op .operands ()[0 ].source ().shape
@@ -308,7 +310,7 @@ def slice_converter(network, paddle_op, inputs):
308310
309311 starts_op = paddle_op .operands ()[1 ].source ().get_defining_op ()
310312 ends_op = paddle_op .operands ()[2 ].source ().get_defining_op ()
311- input_shape_tensor = network . add_shape ( input_tensor ). get_output ( 0 )
313+ input_shape_tensor = trt_shape ( network , input_tensor )
312314 input_rank = len (input_tensor .shape )
313315
314316 starts_tensor = []
@@ -410,7 +412,7 @@ def slice_converter(network, paddle_op, inputs):
410412
411413 # Handle decrease_axis
412414 if decrease_axis :
413- output_shape = network . add_shape ( output_tensor ). get_output ( 0 )
415+ output_shape = trt_shape ( network , output_tensor )
414416 new_shape_dims = []
415417 for i in range (output_shape .shape [0 ]):
416418 if i not in decrease_axis :
0 commit comments