Skip to content

Commit 079677e

Browse files
authored
[Paddle Tensorrt] add pd_op.tanh marker and converter (#68614)
* add tanh marker
* add math converter
* modify marker
* add TensorRT version check
* add shared activation converter function
* modify files
* add version-guarded branch
* modify converter
* modify activation type map
1 parent 0021608 commit 079677e

File tree

5 files changed

+77
-25
lines changed

5 files changed

+77
-25
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,31 @@ class NearestInterV2Pattern
14591459
}
14601460
};
14611461

1462+
// Marker pattern for pd_op.tanh: tags the op with kCanRunTrtAttr so the
// TensorRT lowering pass will convert it. For TensorRT < 8.6 the op is
// rejected for 0-D inputs, which those versions do not support.
class TanhOpPattern : public pir::OpRewritePattern<paddle::dialect::TanhOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::TanhOp>::OpRewritePattern;

  // Returns true (and sets the marker attribute) when the op can run in
  // TensorRT; returns false to leave the op untouched.
  bool MatchAndRewrite(paddle::dialect::TanhOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked: report no match so the rewrite driver terminates.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8600)
    pir::Value x = op.operand_source(0);
    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
    // dyn_cast yields a null type for non-dense-tensor inputs; bail out
    // instead of dereferencing it below.
    if (!x_type) {
      VLOG(3) << "Tanh op input is not a DenseTensorType; skip TRT marking.";
      return false;
    }
    auto x_shape = x_type.dims();
    int dims = x_shape.size();
    if (dims < 1) {
      VLOG(3) << "Tanh op does not support 0 dim input when TensorRT < 8.6.";
      return false;
    }
#endif

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
1486+
14621487
class TrtOpMarkerPass : public pir::PatternRewritePass {
14631488
public:
14641489
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -1534,6 +1559,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
15341559
ps.Add(std::make_unique<MaxOpPattern>(context));
15351560
ps.Add(std::make_unique<BilinearInterpV2Pattern>(context));
15361561
ps.Add(std::make_unique<NearestInterV2Pattern>(context));
1562+
ps.Add(std::make_unique<TanhOpPattern>(context));
15371563
return ps;
15381564
}
15391565
};

python/paddle/tensorrt/impls/activation.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,21 @@
2020
)
2121
from paddle.tensorrt.register import converter_registry
2222

23+
# Lookup table from a Paddle activation op name to the matching TensorRT
# activation kind; shared by every op registered on activation_converter.
activation_type_map = {
    "pd_op.tanh": trt.ActivationType.TANH,
    "pd_op.relu": trt.ActivationType.RELU,
    "pd_op.sigmoid": trt.ActivationType.SIGMOID,
}


@converter_registry.register("pd_op.relu", trt_version="8.x")
@converter_registry.register("pd_op.tanh", trt_version="8.x")
@converter_registry.register("pd_op.sigmoid", trt_version="8.x")
def activation_converter(network, paddle_op, inputs):
    """Lower a Paddle activation op to a TensorRT activation layer.

    Dispatches on ``paddle_op.name()`` through ``activation_type_map`` so a
    single converter serves relu/tanh/sigmoid; returns the layer's output
    tensor.
    """
    act_kind = activation_type_map[paddle_op.name()]
    act_layer = network.add_activation(inputs[0], act_kind)
    return act_layer.get_output(0)
2838

2939

3040
@converter_registry.register("pd_op.softmax", trt_version="8.x")

python/paddle/tensorrt/impls/ops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,3 @@ def sqrt_converter(network, paddle_op, inputs):
2323

2424
sqrt_layer = network.add_unary(input_tensor, trt.UnaryOperation.SQRT)
2525
return sqrt_layer.get_output(0)
26-
27-
28-
@converter_registry.register("pd_op.sigmoid", trt_version="8.x")
29-
def sigmoid_converter(network, paddle_op, inputs):
30-
sigmoid_layer = network.add_activation(
31-
inputs[0], trt.ActivationType.SIGMOID
32-
)
33-
return sigmoid_layer.get_output(0)

test/tensorrt/test_converter_activation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,43 @@ def test_trt_result(self):
4848
self.check_trt_result()
4949

5050

51+
class TestRELUTRTPattern(TensorRTBaseTest):
    """Checks pd_op.relu conversion on a 1-D dynamic-shape input."""

    def setUp(self):
        self.python_api = paddle.nn.functional.relu
        # Single 1-D float32 input; dynamic batch ranges from 1 to 5.
        self.api_args = dict(x=np.random.randn(3).astype("float32"))
        self.program_config = dict(feed_list=["x"])
        self.min_shape = dict(x=[1])
        self.max_shape = dict(x=[5])

    def test_trt_result(self):
        self.check_trt_result()
61+
62+
63+
class TestTANHTRTPattern(TensorRTBaseTest):
    """Checks pd_op.tanh conversion on a 1-D dynamic-shape input."""

    def setUp(self):
        self.python_api = paddle.tanh
        # Single 1-D float32 input; dynamic batch ranges from 1 to 5.
        self.api_args = dict(x=np.random.randn(3).astype("float32"))
        self.program_config = dict(feed_list=["x"])
        self.min_shape = dict(x=[1])
        self.max_shape = dict(x=[5])

    def test_trt_result(self):
        self.check_trt_result()
73+
74+
75+
class TestSigmoidTRTPattern(TensorRTBaseTest):
    """Checks pd_op.sigmoid conversion on a 2-D dynamic-shape input."""

    def setUp(self):
        self.python_api = paddle.nn.functional.sigmoid
        self.api_args = {
            "x": np.random.randn(2, 3).astype(np.float32),
        }
        self.program_config = {"feed_list": ["x"]}
        # Only "x" is fed (see feed_list); the stray "y" shape entries
        # copied from a two-input test have been removed.
        self.min_shape = {"x": [1, 3]}
        self.max_shape = {"x": [5, 3]}

    def test_trt_result(self):
        self.check_trt_result()
87+
88+
5189
if __name__ == '__main__':
5290
unittest.main()

test/tensorrt/test_converter_ops.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,5 @@ def test_trt_result(self):
3434
self.check_trt_result()
3535

3636

37-
class TestSigmoidTRTPattern(TensorRTBaseTest):
38-
def setUp(self):
39-
self.python_api = paddle.nn.functional.sigmoid
40-
self.api_args = {
41-
"x": np.random.randn(2, 3).astype(np.float32),
42-
}
43-
self.program_config = {"feed_list": ["x"]}
44-
self.min_shape = {"x": [1, 3], "y": [1, 3]}
45-
self.max_shape = {"x": [5, 3], "y": [5, 3]}
46-
47-
def test_trt_result(self):
48-
self.check_trt_result()
49-
50-
5137
if __name__ == '__main__':
5238
unittest.main()

0 commit comments

Comments
 (0)