1+
12// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
23//
34// Licensed under the Apache License, Version 2.0 (the "License");
@@ -66,6 +67,8 @@ DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
6667DEFINE_GENERAL_PATTERN (Conv2d, paddle::dialect::Conv2dOp)
6768DEFINE_GENERAL_PATTERN (FusedConv2dAddAct, paddle::dialect::FusedConv2dAddActOp)
6869DEFINE_GENERAL_PATTERN (DepthwiseConv2d, paddle::dialect::DepthwiseConv2dOp)
70+ DEFINE_GENERAL_PATTERN (Shape, paddle::dialect::ShapeOp)
71+ DEFINE_GENERAL_PATTERN (Expand, paddle::dialect::ExpandOp)
6972DEFINE_GENERAL_PATTERN (Sigmoid, paddle::dialect::SigmoidOp)
7073
7174#undef DEFINE_GENERAL_PATTERN
@@ -919,6 +922,172 @@ class MultiplyOpPattern
919922 }
920923};
921924
925+ class SubtractOpPattern
926+ : public pir::OpRewritePattern<paddle::dialect::SubtractOp> {
927+ public:
928+ using pir::OpRewritePattern<paddle::dialect::SubtractOp>::OpRewritePattern;
929+ bool MatchAndRewrite (paddle::dialect::SubtractOp op,
930+ pir::PatternRewriter &rewriter) const override {
931+ if (op->HasAttribute (kCanRunTrtAttr ) &&
932+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
933+ return false ;
934+ }
935+ pir::Value x = op.operand_source (0 );
936+ pir::Value y = op.operand_source (1 );
937+ auto x_dtype = pir::GetDataTypeFromValue (x);
938+ auto y_dtype = pir::GetDataTypeFromValue (y);
939+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
940+ VLOG (3 ) << " elementwise_sub do not support boolean datatype." ;
941+ return false ;
942+ }
943+
944+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
945+ return true ;
946+ }
947+ };
948+
949+ class DivideOpPattern
950+ : public pir::OpRewritePattern<paddle::dialect::DivideOp> {
951+ public:
952+ using pir::OpRewritePattern<paddle::dialect::DivideOp>::OpRewritePattern;
953+ bool MatchAndRewrite (paddle::dialect::DivideOp op,
954+ pir::PatternRewriter &rewriter) const override {
955+ if (op->HasAttribute (kCanRunTrtAttr ) &&
956+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
957+ return false ;
958+ }
959+ pir::Value x = op.operand_source (0 );
960+ pir::Value y = op.operand_source (1 );
961+ auto x_dtype = pir::GetDataTypeFromValue (x);
962+ auto y_dtype = pir::GetDataTypeFromValue (y);
963+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
964+ VLOG (3 ) << " elementwise_div do not support boolean datatype." ;
965+ return false ;
966+ }
967+
968+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
969+ return true ;
970+ }
971+ };
972+
973+ class ElementwisePowOpPattern
974+ : public pir::OpRewritePattern<paddle::dialect::ElementwisePowOp> {
975+ public:
976+ using pir::OpRewritePattern<
977+ paddle::dialect::ElementwisePowOp>::OpRewritePattern;
978+ bool MatchAndRewrite (paddle::dialect::ElementwisePowOp op,
979+ pir::PatternRewriter &rewriter) const override {
980+ if (op->HasAttribute (kCanRunTrtAttr ) &&
981+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
982+ return false ;
983+ }
984+ pir::Value x = op.operand_source (0 );
985+ pir::Value y = op.operand_source (1 );
986+ auto x_dtype = pir::GetDataTypeFromValue (x);
987+ auto y_dtype = pir::GetDataTypeFromValue (y);
988+ if (x_dtype.isa <pir::BoolType>() || x_dtype.isa <pir::Int32Type>() ||
989+ y_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::Int32Type>()) {
990+ VLOG (3 ) << " elementwise_pow do not support"
991+ " boolean datatype and int32 datatype." ;
992+ return false ;
993+ }
994+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
995+ return true ;
996+ }
997+ };
998+ class MinimumOpPattern
999+ : public pir::OpRewritePattern<paddle::dialect::MinimumOp> {
1000+ public:
1001+ using pir::OpRewritePattern<paddle::dialect::MinimumOp>::OpRewritePattern;
1002+ bool MatchAndRewrite (paddle::dialect::MinimumOp op,
1003+ pir::PatternRewriter &rewriter) const override {
1004+ if (op->HasAttribute (kCanRunTrtAttr ) &&
1005+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
1006+ return false ;
1007+ }
1008+ pir::Value x = op.operand_source (0 );
1009+ pir::Value y = op.operand_source (1 );
1010+ auto x_dtype = pir::GetDataTypeFromValue (x);
1011+ auto y_dtype = pir::GetDataTypeFromValue (y);
1012+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
1013+ VLOG (3 ) << " elementwise_min do not support boolean datatype." ;
1014+ return false ;
1015+ }
1016+
1017+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
1018+ return true ;
1019+ }
1020+ };
1021+ class MaximumOpPattern
1022+ : public pir::OpRewritePattern<paddle::dialect::MaximumOp> {
1023+ public:
1024+ using pir::OpRewritePattern<paddle::dialect::MaximumOp>::OpRewritePattern;
1025+ bool MatchAndRewrite (paddle::dialect::MaximumOp op,
1026+ pir::PatternRewriter &rewriter) const override {
1027+ if (op->HasAttribute (kCanRunTrtAttr ) &&
1028+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
1029+ return false ;
1030+ }
1031+ pir::Value x = op.operand_source (0 );
1032+ pir::Value y = op.operand_source (1 );
1033+ auto x_dtype = pir::GetDataTypeFromValue (x);
1034+ auto y_dtype = pir::GetDataTypeFromValue (y);
1035+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
1036+ VLOG (3 ) << " elementwise_max do not support boolean datatype." ;
1037+ return false ;
1038+ }
1039+
1040+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
1041+ return true ;
1042+ }
1043+ };
1044+
1045+ class FloorDivideOpPattern
1046+ : public pir::OpRewritePattern<paddle::dialect::FloorDivideOp> {
1047+ public:
1048+ using pir::OpRewritePattern<paddle::dialect::FloorDivideOp>::OpRewritePattern;
1049+ bool MatchAndRewrite (paddle::dialect::FloorDivideOp op,
1050+ pir::PatternRewriter &rewriter) const override {
1051+ if (op->HasAttribute (kCanRunTrtAttr ) &&
1052+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
1053+ return false ;
1054+ }
1055+ pir::Value x = op.operand_source (0 );
1056+ pir::Value y = op.operand_source (1 );
1057+ auto x_dtype = pir::GetDataTypeFromValue (x);
1058+ auto y_dtype = pir::GetDataTypeFromValue (y);
1059+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
1060+ VLOG (3 ) << " elementwise_floordiv do not support boolean datatype." ;
1061+ return false ;
1062+ }
1063+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
1064+ return true ;
1065+ }
1066+ };
1067+
1068+ class RemainderOpPattern
1069+ : public pir::OpRewritePattern<paddle::dialect::RemainderOp> {
1070+ public:
1071+ using pir::OpRewritePattern<paddle::dialect::RemainderOp>::OpRewritePattern;
1072+ bool MatchAndRewrite (paddle::dialect::RemainderOp op,
1073+ pir::PatternRewriter &rewriter) const override {
1074+ if (op->HasAttribute (kCanRunTrtAttr ) &&
1075+ op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
1076+ return false ;
1077+ }
1078+ pir::Value x = op.operand_source (0 );
1079+ pir::Value y = op.operand_source (1 );
1080+ auto x_dtype = pir::GetDataTypeFromValue (x);
1081+ auto y_dtype = pir::GetDataTypeFromValue (y);
1082+ if (x_dtype.isa <pir::BoolType>() || y_dtype.isa <pir::BoolType>()) {
1083+ VLOG (3 ) << " elementwise_mod do not support boolean datatype." ;
1084+ return false ;
1085+ }
1086+
1087+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
1088+ return true ;
1089+ }
1090+ };
9221091class TrtOpMarkerPass : public pir ::PatternRewritePass {
9231092 public:
9241093 TrtOpMarkerPass () : pir::PatternRewritePass(" trt_op_marker_pass" , 2 ) {}
@@ -948,6 +1117,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
9481117 ADD_PATTERN (DepthwiseConv2d)
9491118 ADD_PATTERN (Nonzero)
9501119 ADD_PATTERN (Gelu)
1120+ ADD_PATTERN (Shape)
1121+ ADD_PATTERN (Expand)
9511122 ADD_PATTERN (Sigmoid)
9521123
9531124#undef ADD_PATTERN
@@ -974,6 +1145,13 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
9741145 ps.Add (std::make_unique<SplitWithNumOpPattern>(context));
9751146 ps.Add (std::make_unique<GreaterEqualOpPattern>(context));
9761147 ps.Add (std::make_unique<MultiplyOpPattern>(context));
1148+ ps.Add (std::make_unique<SubtractOpPattern>(context));
1149+ ps.Add (std::make_unique<DivideOpPattern>(context));
1150+ ps.Add (std::make_unique<ElementwisePowOpPattern>(context));
1151+ ps.Add (std::make_unique<MinimumOpPattern>(context));
1152+ ps.Add (std::make_unique<MaximumOpPattern>(context));
1153+ ps.Add (std::make_unique<FloorDivideOpPattern>(context));
1154+ ps.Add (std::make_unique<RemainderOpPattern>(context));
9771155 return ps;
9781156 }
9791157};
0 commit comments