Skip to content

Commit 027b466

Browse files
layssy and layssy3 authored
[Paddle TensorRT] add new marker (#67753)
* add new marker * add test for some op marker * 修改trt_op_marker_pass.cc代码规范 * 修改单测文件代码规范 --------- Co-authored-by: lay <[email protected]>
1 parent 6a3ab98 commit 027b466

10 files changed

+696
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
23
//
34
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -66,6 +67,8 @@ DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
6667
DEFINE_GENERAL_PATTERN(Conv2d, paddle::dialect::Conv2dOp)
6768
DEFINE_GENERAL_PATTERN(FusedConv2dAddAct, paddle::dialect::FusedConv2dAddActOp)
6869
DEFINE_GENERAL_PATTERN(DepthwiseConv2d, paddle::dialect::DepthwiseConv2dOp)
70+
DEFINE_GENERAL_PATTERN(Shape, paddle::dialect::ShapeOp)
71+
DEFINE_GENERAL_PATTERN(Expand, paddle::dialect::ExpandOp)
6972
DEFINE_GENERAL_PATTERN(Sigmoid, paddle::dialect::SigmoidOp)
7073

7174
#undef DEFINE_GENERAL_PATTERN
@@ -919,6 +922,172 @@ class MultiplyOpPattern
919922
}
920923
};
921924

925+
class SubtractOpPattern
926+
: public pir::OpRewritePattern<paddle::dialect::SubtractOp> {
927+
public:
928+
using pir::OpRewritePattern<paddle::dialect::SubtractOp>::OpRewritePattern;
929+
bool MatchAndRewrite(paddle::dialect::SubtractOp op,
930+
pir::PatternRewriter &rewriter) const override {
931+
if (op->HasAttribute(kCanRunTrtAttr) &&
932+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
933+
return false;
934+
}
935+
pir::Value x = op.operand_source(0);
936+
pir::Value y = op.operand_source(1);
937+
auto x_dtype = pir::GetDataTypeFromValue(x);
938+
auto y_dtype = pir::GetDataTypeFromValue(y);
939+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
940+
VLOG(3) << "elementwise_sub do not support boolean datatype.";
941+
return false;
942+
}
943+
944+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
945+
return true;
946+
}
947+
};
948+
949+
class DivideOpPattern
950+
: public pir::OpRewritePattern<paddle::dialect::DivideOp> {
951+
public:
952+
using pir::OpRewritePattern<paddle::dialect::DivideOp>::OpRewritePattern;
953+
bool MatchAndRewrite(paddle::dialect::DivideOp op,
954+
pir::PatternRewriter &rewriter) const override {
955+
if (op->HasAttribute(kCanRunTrtAttr) &&
956+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
957+
return false;
958+
}
959+
pir::Value x = op.operand_source(0);
960+
pir::Value y = op.operand_source(1);
961+
auto x_dtype = pir::GetDataTypeFromValue(x);
962+
auto y_dtype = pir::GetDataTypeFromValue(y);
963+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
964+
VLOG(3) << "elementwise_div do not support boolean datatype.";
965+
return false;
966+
}
967+
968+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
969+
return true;
970+
}
971+
};
972+
973+
class ElementwisePowOpPattern
974+
: public pir::OpRewritePattern<paddle::dialect::ElementwisePowOp> {
975+
public:
976+
using pir::OpRewritePattern<
977+
paddle::dialect::ElementwisePowOp>::OpRewritePattern;
978+
bool MatchAndRewrite(paddle::dialect::ElementwisePowOp op,
979+
pir::PatternRewriter &rewriter) const override {
980+
if (op->HasAttribute(kCanRunTrtAttr) &&
981+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
982+
return false;
983+
}
984+
pir::Value x = op.operand_source(0);
985+
pir::Value y = op.operand_source(1);
986+
auto x_dtype = pir::GetDataTypeFromValue(x);
987+
auto y_dtype = pir::GetDataTypeFromValue(y);
988+
if (x_dtype.isa<pir::BoolType>() || x_dtype.isa<pir::Int32Type>() ||
989+
y_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::Int32Type>()) {
990+
VLOG(3) << "elementwise_pow do not support"
991+
"boolean datatype and int32 datatype.";
992+
return false;
993+
}
994+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
995+
return true;
996+
}
997+
};
998+
class MinimumOpPattern
999+
: public pir::OpRewritePattern<paddle::dialect::MinimumOp> {
1000+
public:
1001+
using pir::OpRewritePattern<paddle::dialect::MinimumOp>::OpRewritePattern;
1002+
bool MatchAndRewrite(paddle::dialect::MinimumOp op,
1003+
pir::PatternRewriter &rewriter) const override {
1004+
if (op->HasAttribute(kCanRunTrtAttr) &&
1005+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1006+
return false;
1007+
}
1008+
pir::Value x = op.operand_source(0);
1009+
pir::Value y = op.operand_source(1);
1010+
auto x_dtype = pir::GetDataTypeFromValue(x);
1011+
auto y_dtype = pir::GetDataTypeFromValue(y);
1012+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
1013+
VLOG(3) << "elementwise_min do not support boolean datatype.";
1014+
return false;
1015+
}
1016+
1017+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1018+
return true;
1019+
}
1020+
};
1021+
class MaximumOpPattern
1022+
: public pir::OpRewritePattern<paddle::dialect::MaximumOp> {
1023+
public:
1024+
using pir::OpRewritePattern<paddle::dialect::MaximumOp>::OpRewritePattern;
1025+
bool MatchAndRewrite(paddle::dialect::MaximumOp op,
1026+
pir::PatternRewriter &rewriter) const override {
1027+
if (op->HasAttribute(kCanRunTrtAttr) &&
1028+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1029+
return false;
1030+
}
1031+
pir::Value x = op.operand_source(0);
1032+
pir::Value y = op.operand_source(1);
1033+
auto x_dtype = pir::GetDataTypeFromValue(x);
1034+
auto y_dtype = pir::GetDataTypeFromValue(y);
1035+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
1036+
VLOG(3) << "elementwise_max do not support boolean datatype.";
1037+
return false;
1038+
}
1039+
1040+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1041+
return true;
1042+
}
1043+
};
1044+
1045+
class FloorDivideOpPattern
1046+
: public pir::OpRewritePattern<paddle::dialect::FloorDivideOp> {
1047+
public:
1048+
using pir::OpRewritePattern<paddle::dialect::FloorDivideOp>::OpRewritePattern;
1049+
bool MatchAndRewrite(paddle::dialect::FloorDivideOp op,
1050+
pir::PatternRewriter &rewriter) const override {
1051+
if (op->HasAttribute(kCanRunTrtAttr) &&
1052+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1053+
return false;
1054+
}
1055+
pir::Value x = op.operand_source(0);
1056+
pir::Value y = op.operand_source(1);
1057+
auto x_dtype = pir::GetDataTypeFromValue(x);
1058+
auto y_dtype = pir::GetDataTypeFromValue(y);
1059+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
1060+
VLOG(3) << "elementwise_floordiv do not support boolean datatype.";
1061+
return false;
1062+
}
1063+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1064+
return true;
1065+
}
1066+
};
1067+
1068+
class RemainderOpPattern
1069+
: public pir::OpRewritePattern<paddle::dialect::RemainderOp> {
1070+
public:
1071+
using pir::OpRewritePattern<paddle::dialect::RemainderOp>::OpRewritePattern;
1072+
bool MatchAndRewrite(paddle::dialect::RemainderOp op,
1073+
pir::PatternRewriter &rewriter) const override {
1074+
if (op->HasAttribute(kCanRunTrtAttr) &&
1075+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1076+
return false;
1077+
}
1078+
pir::Value x = op.operand_source(0);
1079+
pir::Value y = op.operand_source(1);
1080+
auto x_dtype = pir::GetDataTypeFromValue(x);
1081+
auto y_dtype = pir::GetDataTypeFromValue(y);
1082+
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
1083+
VLOG(3) << "elementwise_mod do not support boolean datatype.";
1084+
return false;
1085+
}
1086+
1087+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1088+
return true;
1089+
}
1090+
};
9221091
class TrtOpMarkerPass : public pir::PatternRewritePass {
9231092
public:
9241093
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -948,6 +1117,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
9481117
ADD_PATTERN(DepthwiseConv2d)
9491118
ADD_PATTERN(Nonzero)
9501119
ADD_PATTERN(Gelu)
1120+
ADD_PATTERN(Shape)
1121+
ADD_PATTERN(Expand)
9511122
ADD_PATTERN(Sigmoid)
9521123

9531124
#undef ADD_PATTERN
@@ -974,6 +1145,13 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
9741145
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
9751146
ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
9761147
ps.Add(std::make_unique<MultiplyOpPattern>(context));
1148+
ps.Add(std::make_unique<SubtractOpPattern>(context));
1149+
ps.Add(std::make_unique<DivideOpPattern>(context));
1150+
ps.Add(std::make_unique<ElementwisePowOpPattern>(context));
1151+
ps.Add(std::make_unique<MinimumOpPattern>(context));
1152+
ps.Add(std::make_unique<MaximumOpPattern>(context));
1153+
ps.Add(std::make_unique<FloorDivideOpPattern>(context));
1154+
ps.Add(std::make_unique<RemainderOpPattern>(context));
9771155
return ps;
9781156
}
9791157
};
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from pass_test import PassTest
19+
20+
import paddle
21+
from paddle.base import core
22+
23+
24+
class TestDivideTRTPattern(PassTest):
25+
def is_program_valid(self, program=None):
26+
return True
27+
28+
def sample_program(self):
29+
with paddle.pir_utils.IrGuard():
30+
main_prog = paddle.static.Program()
31+
start_prog = paddle.static.Program()
32+
with paddle.pir.core.program_guard(main_prog, start_prog):
33+
x = paddle.static.data(name='x', shape=[3], dtype='float32')
34+
y = paddle.static.data(name='y', shape=[3], dtype='float32')
35+
divide_out = paddle.divide(x, y)
36+
out = paddle.assign(divide_out)
37+
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
38+
self.feeds = {
39+
"x": np.array([2, 3, 4]).astype("float32"),
40+
"y": np.array([1, 5, 2]).astype("float32"),
41+
}
42+
self.fetch_list = [out]
43+
self.valid_op_map = {
44+
"pd_op.fusion_transpose_flatten_concat": 0,
45+
}
46+
yield [main_prog, start_prog], False
47+
48+
def setUp(self):
49+
if core.is_compiled_with_cuda():
50+
self.places.append(paddle.CUDAPlace(0))
51+
self.trt_expected_ops = {"pd_op.divide"}
52+
53+
def test_check_output(self):
54+
self.check_pass_correct()
55+
56+
57+
if __name__ == '__main__':
58+
unittest.main()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from pass_test import PassTest
19+
20+
import paddle
21+
from paddle.base import core
22+
23+
24+
class TestElementWisePowTRTPattern(PassTest):
25+
def is_program_valid(self, program=None):
26+
return True
27+
28+
def sample_program(self):
29+
with paddle.pir_utils.IrGuard():
30+
main_prog = paddle.static.Program()
31+
start_prog = paddle.static.Program()
32+
with paddle.pir.core.program_guard(main_prog, start_prog):
33+
x = paddle.static.data(name='x', shape=[3], dtype='float32')
34+
y = paddle.static.data(name='y', shape=[1], dtype='float32')
35+
pow_out = paddle.pow(x, y)
36+
out = paddle.assign(pow_out)
37+
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
38+
self.feeds = {
39+
"x": np.array([1, 2, 3]).astype("float32"),
40+
"y": np.array([2]).astype("float32"),
41+
}
42+
self.fetch_list = [out]
43+
self.valid_op_map = {
44+
"pd_op.fusion_transpose_flatten_concat": 0,
45+
}
46+
yield [main_prog, start_prog], False
47+
48+
def setUp(self):
49+
if core.is_compiled_with_cuda():
50+
self.places.append(paddle.CUDAPlace(0))
51+
self.trt_expected_ops = {"pd_op.elementwise_pow"}
52+
53+
def test_check_output(self):
54+
self.check_pass_correct()
55+
56+
57+
if __name__ == '__main__':
58+
unittest.main()
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from pass_test import PassTest
19+
20+
import paddle
21+
from paddle.base import core
22+
23+
24+
class TestExpandTRTPattern(PassTest):
25+
def is_program_valid(self, program=None):
26+
return True
27+
28+
def sample_program(self):
29+
with paddle.pir_utils.IrGuard():
30+
main_prog = paddle.static.Program()
31+
start_prog = paddle.static.Program()
32+
with paddle.pir.core.program_guard(main_prog, start_prog):
33+
x = paddle.static.data(name="x", shape=[3], dtype="float32")
34+
expand_out = paddle.expand(x, shape=[2, 3])
35+
out = paddle.assign(expand_out)
36+
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
37+
self.feeds = {
38+
"x": np.array([[1, 2, 3]]).astype("float32"),
39+
}
40+
self.fetch_list = [out]
41+
self.valid_op_map = {
42+
"pd_op.fusion_transpose_flatten_concat": 0,
43+
}
44+
yield [main_prog, start_prog], False
45+
46+
def setUp(self):
47+
if core.is_compiled_with_cuda():
48+
self.places.append(paddle.CUDAPlace(0))
49+
self.trt_expected_ops = {"pd_op.expand"}
50+
51+
def test_check_output(self):
52+
self.check_pass_correct()
53+
54+
55+
if __name__ == '__main__':
56+
unittest.main()

0 commit comments

Comments
 (0)