Skip to content

Commit 15259f9

Browse files
authored
【SCU】【Paddle TensorRT No.57】Add pd_op.temporal_shift converter (#69848)
Squashed commit history: add converter; fix code style; update trt_op_marker_pass.cc (several iterations); add fp16 support; add tests; add opt-shape config; update test_converter_others.py; remove unused size handling; assorted fixes.
1 parent 0c14ec7 commit 15259f9

File tree

3 files changed

+275
-0
lines changed

3 files changed

+275
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,27 @@ class OneHotOpPattern
21812181
}
21822182
};
21832183

2184+
class TemporalShiftOpPattern
2185+
: public pir::OpRewritePattern<paddle::dialect::TemporalShiftOp> {
2186+
public:
2187+
using pir::OpRewritePattern<
2188+
paddle::dialect::TemporalShiftOp>::OpRewritePattern;
2189+
2190+
bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op,
2191+
pir::PatternRewriter &rewriter) const override {
2192+
if (op->HasAttribute(kCanRunTrtAttr) &&
2193+
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
2194+
return false;
2195+
}
2196+
if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) {
2197+
VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num";
2198+
return false;
2199+
}
2200+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
2201+
return true;
2202+
}
2203+
};
2204+
21842205
class InstanceNormOpPattern
21852206
: public pir::OpRewritePattern<paddle::dialect::InstanceNormOp> {
21862207
public:
@@ -2388,6 +2409,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
23882409
ps.Add(std::make_unique<TanhOpPattern>(context));
23892410
ps.Add(std::make_unique<CeluOpPattern>(context));
23902411
ps.Add(std::make_unique<OneHotOpPattern>(context));
2412+
ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
23912413
ps.Add(std::make_unique<InstanceNormOpPattern>(context));
23922414
ps.Add(std::make_unique<AffineChannelOpPattern>(context));
23932415
return ps;

python/paddle/tensorrt/impls/others.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
trt_concat,
2828
trt_prod,
2929
trt_shape,
30+
trt_sub,
3031
trt_sum,
3132
)
3233
from paddle.tensorrt.register import converter_registry
@@ -274,6 +275,112 @@ def share_data_converter(network, paddle_op, inputs):
274275
return identity_layer.get_output(0)
275276

276277

278+
@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
def temporal_shift_converter(network, paddle_op, inputs):
    """Lower pd_op.temporal_shift to TensorRT layers.

    The op shifts a fraction of the channels one segment (time step)
    backward, the same fraction one segment forward, and keeps the rest.
    Implemented as: reshape to [N, T, C, H, W] -> zero-pad the T axis ->
    three dynamic slices -> concat on the channel axis -> reshape back.
    """
    attrs = paddle_op.attrs()
    shift_ratio = attrs["shift_ratio"]
    seg_num = attrs["seg_num"]
    layout = attrs.get("data_format", "NCHW")

    x = inputs[0]
    if layout == "NHWC":
        # Work internally in NCHW; transpose back at the end.
        to_nchw = network.add_shuffle(x)
        to_nchw.first_transpose = trt.Permutation([0, 3, 1, 2])
        x = to_nchw.get_output(0)

    c, h, w = x.shape[1], x.shape[2], x.shape[3]

    # [N*T, C, H, W] -> [N, T, C, H, W]
    split_time = network.add_shuffle(x)
    split_time.reshape_dims = trt.Dims([-1, seg_num, c, h, w])
    x = split_time.get_output(0)

    # Zero-pad one frame on each side of the T axis: [N, T + 2, C, H, W].
    rank = 5
    pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
    post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
    zeros = add_1D_constant_layer(network, [0] * rank)
    pad_start = trt_sub(network, zeros, pre_pad)
    pad_size = trt_sum(
        network, trt_shape(network, x), trt_sum(network, pre_pad, post_pad)
    )
    stride = [1] * rank
    placeholder = stride  # overridden below via set_input

    pad_layer = network.add_slice(x, placeholder, placeholder, stride)
    pad_layer.set_input(1, pad_start)
    pad_layer.set_input(2, pad_size)

    # FILL mode pads out-of-bounds reads with zeros; the enum was renamed
    # from SliceMode to SampleMode in TensorRT 8.5.
    major = int(trt.__version__.split('.')[0])
    if major > 8 or (major == 8 and int(trt.__version__.split('.')[1]) >= 5):
        pad_layer.mode = trt.SampleMode.FILL
    else:
        pad_layer.mode = trt.SliceMode.FILL

    shift_c = int(c * shift_ratio)
    shift_c2 = int(c * shift_ratio * 2)

    # Slice 1: channels [0, shift_c) taken from frame t-1 (pad offset 0).
    # Slice 2: channels [shift_c, shift_c2) taken from frame t+1 (offset 2).
    # Slice 3: channels [shift_c2, C) taken from frame t (offset 1).
    start1 = zeros
    start2 = add_1D_constant_layer(network, [0, 2, shift_c, 0, 0])
    start3 = add_1D_constant_layer(network, [0, 1, shift_c2, 0, 0])

    base = trt_shape(network, x)
    size1 = trt_sub(
        network, base, add_1D_constant_layer(network, [0, 0, c - shift_c, 0, 0])
    )
    size2 = trt_sub(
        network,
        base,
        add_1D_constant_layer(network, [0, 0, c + shift_c - shift_c2, 0, 0]),
    )
    size3 = trt_sub(
        network, base, add_1D_constant_layer(network, [0, 0, shift_c2, 0, 0])
    )

    padded = pad_layer.get_output(0)

    def _dynamic_slice(begin, extent):
        # start/shape placeholders are replaced by the dynamic tensors.
        layer = network.add_slice(
            padded, start=placeholder, shape=placeholder, stride=stride
        )
        layer.set_input(1, begin)
        layer.set_input(2, extent)
        return layer

    slice1 = _dynamic_slice(start1, size1)
    slice2 = _dynamic_slice(start2, size2)
    slice3 = _dynamic_slice(start3, size3)

    # When shift_c == 0 the first slice is empty and must be dropped.
    if shift_c == 0:
        pieces = [slice2.get_output(0), slice3.get_output(0)]
    else:
        pieces = [
            slice1.get_output(0),
            slice2.get_output(0),
            slice3.get_output(0),
        ]
    concat = network.add_concatenation(pieces)
    concat.axis = 2

    # [N, T, C, H, W] -> [N*T, C, H, W]
    merge_time = network.add_shuffle(concat.get_output(0))
    merge_time.reshape_dims = trt.Dims([-1, c, h, w])

    if layout == "NHWC":
        to_nhwc = network.add_shuffle(merge_time.get_output(0))
        to_nhwc.first_transpose = trt.Permutation([0, 2, 3, 1])
        return to_nhwc.get_output(0)
    return merge_time.get_output(0)
382+
383+
277384
@converter_registry.register("pd_op.anchor_generator", trt_version="8.x")
278385
def anchor_generator_converter(network, paddle_op, inputs):
279386
inputs = inputs[0]

test/tensorrt/test_converter_others.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,152 @@ def test_trt_result(self):
406406
self.check_trt_result()
407407

408408

409+
class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
    """temporal_shift converter: NCHW layout, baseline settings."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
            seg_num=2,
            shift_ratio=0.2,
            data_format="NCHW",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.opt_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
428+
429+
430+
class TestTemporalShiftTRTPatternZeroSlice(TensorRTBaseTest):
    """temporal_shift with C=2 and ratio 0.2 so int(C*ratio) == 0 —
    exercises the converter's empty-first-slice branch."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 2, 7, 7]).astype(np.float32),
            seg_num=2,
            shift_ratio=0.2,
            data_format="NCHW",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 2, 7, 7]}
        self.opt_shape = {"x": [2, 2, 7, 7]}
        self.max_shape = {"x": [8, 2, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
449+
450+
451+
class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
    """temporal_shift with seg_num=4 (batch must be a multiple of seg_num)."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
            seg_num=4,
            shift_ratio=0.2,
            data_format="NCHW",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [4, 9, 7, 7]}
        self.opt_shape = {"x": [4, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
470+
471+
472+
class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
    """temporal_shift with a larger shift_ratio (0.4)."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
            seg_num=2,
            shift_ratio=0.4,
            data_format="NCHW",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.opt_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
491+
492+
493+
class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
    """temporal_shift with NHWC input layout (converter transposes to NCHW
    and back)."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        # Keyword order mirrors the API's positional order, incl. name=None.
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
            seg_num=2,
            shift_ratio=0.2,
            name=None,
            data_format="NHWC",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.opt_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
513+
514+
515+
class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
    """temporal_shift with a wider dynamic batch range (2..10)."""

    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
            seg_num=2,
            shift_ratio=0.2,
            data_format="NCHW",
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.opt_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [10, 9, 7, 7]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
534+
535+
536+
def wrapper_temporal_shift(x):
    # Single-input wrapper with fixed settings; used by the negative
    # marker test. Keyword arguments, so ordering is irrelevant.
    return paddle.nn.functional.temporal_shift(
        shift_ratio=0.2, seg_num=2, x=x
    )
538+
539+
540+
class TestTemporalShiftTRTPatternError1(TensorRTBaseTest):
    """Negative marker case: for the wrapped call the TRT marker pass is
    expected to leave pd_op.temporal_shift unmarked (expected_result=False)."""

    def setUp(self):
        self.python_api = wrapper_temporal_shift
        self.api_args = dict(
            x=np.random.random([4, 9, 7, 7]).astype(np.float32),
        )
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.opt_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [10, 9, 7, 7]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)
553+
554+
409555
def affine_channel(x, scale_shape, bias_shape, layout):
410556
scale = paddle.static.create_parameter(
411557
shape=scale_shape, dtype='float32', name="scale"

0 commit comments

Comments
 (0)