Skip to content

Commit 461fbd1

Browse files
authored
【SCU】【Paddle TensorRT No.4】Add pd_op.stanh converter (#69539)
* add_tanh * fix codestyle * fix codestyle * fix * fix codestyle * fix codestyle * fix
1 parent 7679609 commit 461fbd1

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ DEFINE_GENERAL_PATTERN(Swish, paddle::dialect::SwishOp)
8484
DEFINE_GENERAL_PATTERN(Log, paddle::dialect::LogOp)
8585
DEFINE_GENERAL_PATTERN(Floor, paddle::dialect::FloorOp)
8686
DEFINE_GENERAL_PATTERN(Roll, paddle::dialect::RollOp)
87+
DEFINE_GENERAL_PATTERN(Stanh, paddle::dialect::StanhOp)
8788
DEFINE_GENERAL_PATTERN(Softplus, paddle::dialect::SoftplusOp)
8889
DEFINE_GENERAL_PATTERN(ThresholdedRelu, paddle::dialect::ThresholdedReluOp)
8990
DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
@@ -2166,6 +2167,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
21662167
ADD_PATTERN(Log)
21672168
ADD_PATTERN(Floor)
21682169
ADD_PATTERN(Roll)
2170+
ADD_PATTERN(Stanh)
21692171
ADD_PATTERN(Softplus)
21702172
ADD_PATTERN(ThresholdedRelu)
21712173
ADD_PATTERN(Flip)

python/paddle/tensorrt/impls/activation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,17 @@ def swish_silu_converter(network, paddle_op, inputs):
139139
return trt_prod(network, inputs[0], layer_output)
140140

141141

142+
@converter_registry.register("pd_op.stanh", trt_version="8.x")
143+
def stanh_converter(network, paddle_op, inputs):
144+
x = inputs[0]
145+
scale_a = paddle_op.attrs()["scale_a"]
146+
scale_b = paddle_op.attrs()["scale_b"]
147+
stanh_layer = network.add_activation(x, trt.ActivationType.SCALED_TANH)
148+
stanh_layer.alpha = scale_b
149+
stanh_layer.beta = scale_a
150+
return stanh_layer.get_output(0)
151+
152+
142153
@converter_registry.register("pd_op.mish", trt_version="8.x")
143154
def mish_converter(network, paddle_op, inputs):
144155
x = inputs[0]

test/tensorrt/test_converter_activation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,22 @@ def test_trt_result(self):
128128
self.check_trt_result()
129129

130130

131+
class TestStanhFloatTRTPattern(TensorRTBaseTest):
132+
def setUp(self):
133+
self.python_api = paddle.stanh
134+
self.api_args = {
135+
"x": np.random.randn(2, 3).astype("float32"),
136+
"scale_a": 0.67,
137+
"scale_b": 1.7159,
138+
}
139+
self.program_config = {"feed_list": ["x"]}
140+
self.min_shape = {"x": [1, 3]}
141+
self.max_shape = {"x": [5, 3]}
142+
143+
def test_trt_result(self):
144+
self.check_trt_result()
145+
146+
131147
class TestCeluTRTPattern(TensorRTBaseTest):
132148
def setUp(self):
133149
self.python_api = paddle.nn.functional.celu

0 commit comments

Comments
 (0)