Skip to content

Commit 252f83d

Browse files
Adapt 3 manipulation converters to TRT 10
- "pd_op.expand" - "pd_op.expand_as" - "pd_op.slice"
1 parent 85d9888 commit 252f83d

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

python/paddle/tensorrt/impls/manipulation.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def squeeze_converter(network, paddle_op, inputs):
231231
return layer.get_output(0)
232232

233233

234-
@converter_registry.register("pd_op.expand", trt_version="8.x")
234+
@converter_registry.register("pd_op.expand", trt_version="trt_version_ge=8.0")
235235
def expand_converter(network, paddle_op, inputs):
236236
input = inputs[0]
237237
input_dims = input.shape
@@ -253,7 +253,9 @@ def expand_converter(network, paddle_op, inputs):
253253
return trt_expand(network, input, rank, shape_tensor, shape_rank)
254254

255255

256-
@converter_registry.register("pd_op.expand_as", trt_version="8.x")
256+
@converter_registry.register(
257+
"pd_op.expand_as", trt_version="trt_version_ge=8.0"
258+
)
257259
def expand_as_converter(network, paddle_op, inputs):
258260
input = inputs[0]
259261
input_dims = input.shape
@@ -299,7 +301,7 @@ def cast_converter(network, paddle_op, inputs):
299301
return cast_layer.get_output(0)
300302

301303

302-
@converter_registry.register("pd_op.slice", trt_version="8.x")
304+
@converter_registry.register("pd_op.slice", trt_version="trt_version_ge=8.0")
303305
def slice_converter(network, paddle_op, inputs):
304306
input_tensor = inputs[0]
305307
input_shape = paddle_op.operands()[0].source().shape
@@ -308,7 +310,7 @@ def slice_converter(network, paddle_op, inputs):
308310

309311
starts_op = paddle_op.operands()[1].source().get_defining_op()
310312
ends_op = paddle_op.operands()[2].source().get_defining_op()
311-
input_shape_tensor = network.add_shape(input_tensor).get_output(0)
313+
input_shape_tensor = trt_shape(network, input_tensor)
312314
input_rank = len(input_tensor.shape)
313315

314316
starts_tensor = []
@@ -410,7 +412,7 @@ def slice_converter(network, paddle_op, inputs):
410412

411413
# Handle decrease_axis
412414
if decrease_axis:
413-
output_shape = network.add_shape(output_tensor).get_output(0)
415+
output_shape = trt_shape(network, output_tensor)
414416
new_shape_dims = []
415417
for i in range(output_shape.shape[0]):
416418
if i not in decrease_axis:

0 commit comments

Comments
 (0)