Skip to content

Commit a2ec045

Browse files
Adapt 3 manipulation converters to TRT 10
- "pd_op.expand" - "pd_op.expand_as" - "pd_op.slice"
1 parent 0377dce commit a2ec045

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python/paddle/tensorrt/impls/manipulation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def squeeze_converter(network, paddle_op, inputs):
265265
return layer.get_output(0)
266266

267267

268-
@converter_registry.register("pd_op.expand", trt_version="8.x")
268+
@converter_registry.register("pd_op.expand", trt_version="trt_version_ge=8.0")
269269
def expand_converter(network, paddle_op, inputs):
270270
input = inputs[0]
271271
input_dims = input.shape
@@ -287,7 +287,9 @@ def expand_converter(network, paddle_op, inputs):
287287
return trt_expand(network, input, rank, shape_tensor, shape_rank)
288288

289289

290-
@converter_registry.register("pd_op.expand_as", trt_version="8.x")
290+
@converter_registry.register(
291+
"pd_op.expand_as", trt_version="trt_version_ge=8.0"
292+
)
291293
def expand_as_converter(network, paddle_op, inputs):
292294
input = inputs[0]
293295
input_dims = input.shape
@@ -333,15 +335,15 @@ def cast_converter(network, paddle_op, inputs):
333335
return cast_layer.get_output(0)
334336

335337

336-
@converter_registry.register("pd_op.slice", trt_version="8.x")
338+
@converter_registry.register("pd_op.slice", trt_version="trt_version_ge=8.0")
337339
def slice_converter(network, paddle_op, inputs):
338340
input_tensor = inputs[0]
339341
axes = paddle_op.attrs()["axes"]
340342
decrease_axis = paddle_op.attrs().get("decrease_axis")
341343

342344
starts_op = paddle_op.operands()[1].source().get_defining_op()
343345
ends_op = paddle_op.operands()[2].source().get_defining_op()
344-
input_shape_tensor = network.add_shape(input_tensor).get_output(0)
346+
input_shape_tensor = trt_shape(network, input_tensor)
345347
input_rank = len(input_tensor.shape)
346348

347349
starts_tensor = []

0 commit comments

Comments
 (0)