Skip to content

Commit 05091af

Browse files
committed
fix bug
1 parent 97e1610 commit 05091af

File tree

1 file changed

+8
-1
lines changed
  • python/paddle/tensorrt/impls

1 file changed

+8
-1
lines changed

python/paddle/tensorrt/impls/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ def cumsum_converter(network, paddle_op, inputs):
275275
input_sliced.set_input(2, slice_shape)
276276

277277
# squeeze axis
278-
shape_list.pop(axis)
278+
if rank > 1:
279+
shape_list.pop(axis)
279280
new_shape = network.add_concatenation(shape_list).get_output(0)
280281
squeeze_layer = network.add_shuffle(input_sliced.get_output(0))
281282
squeeze_layer.set_input(1, new_shape)
@@ -301,6 +302,12 @@ def cumsum_converter(network, paddle_op, inputs):
301302
lhs_val, cast_tensor, trt.ElementWiseOperation.PROD
302303
).get_output(0)
303304

305+
# Set as scalar
306+
if rank == 1:
307+
shuffle_layer = network.add_shuffle(zero_tensor)
308+
shuffle_layer.reshape_dims = trt.Dims()
309+
zero_tensor = shuffle_layer.get_output(0)
310+
304311
# Cycle and add according to the axis
305312
running_sum = loop.add_recurrence(zero_tensor)
306313
running_sum_tensor = running_sum.get_output(0)

0 commit comments

Comments
 (0)