Skip to content

Commit 18bb70a

Browse files
PolaKumaEnigmatisms
authored and committed
【SCU】【Paddle TensorRT No.48】Add pd_op.pad converter (PaddlePaddle#69673)
* add_pad_op * fix codestyle * fix codestyle * update * Update trt_op_marker_pass.cc * update * update * Update trt_op_marker_pass.cc * fix codestyle * fix codestyle * fix opt * Update manipulation.py * Update trt_op_marker_pass.cc * Update trt_op_marker_pass.cc * update * Update test_converter_manipulation.py * Update test_converter_manipulation.py * fix * Update trt_op_marker_pass.cc * fix codestyle * fix codestyle * Update trt_op_marker_pass.cc
1 parent 9c38ef8 commit 18bb70a

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@ class FlattenOpPattern
10651065
return true;
10661066
}
10671067
};
1068+
10681069
class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
10691070
public:
10701071
using pir::OpRewritePattern<paddle::dialect::CastOp>::OpRewritePattern;
@@ -1092,6 +1093,63 @@ class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
10921093
}
10931094
};
10941095

1096+
// Marks pd_op.pad as runnable by TensorRT when its configuration fits the
// constraints of TRT's padding layer: constant zero pad value, input rank
// >= 2, paddings covering every dimension (2 entries per dim), and non-zero
// padding restricted to the last two dimensions.
class PadOpPattern : public pir::OpRewritePattern<paddle::dialect::PadOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::PadOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::PadOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked; nothing further to do.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    pir::Value pad_value_tensor = op.operand_source(1);
    if (!op->HasAttribute("paddings") || !pad_value_tensor) {
      VLOG(3) << "PadOp must has 'paddings' and 'pad_value'.";
      return false;
    }
    // Hoist the defining-op lookup (the original called it twice) and guard
    // against a null defining op — e.g. when pad_value is a block argument —
    // before calling isa<>() on it.
    pir::Operation *pad_value_def = pir::GetDefiningOpForInput(op, 1);
    if (pad_value_def && pad_value_def->isa<paddle::dialect::FullOp>()) {
      auto full_op = pad_value_def->dyn_cast<paddle::dialect::FullOp>();
      auto pad_value =
          full_op->attribute<paddle::dialect::ScalarAttribute>("value")
              .data()
              .to<float>();
      // TRT's padding layer only supports zero padding.
      if (pad_value != 0.0f) {
        VLOG(3) << "The pad layer of TRT only support zero.";
        return false;
      }
    }
    auto paddings_attr = op->attribute<pir::ArrayAttribute>("paddings");
    std::vector<int> paddings;
    for (const auto &attr : paddings_attr.AsVector()) {
      paddings.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
    }
    int pad_size = static_cast<int>(paddings.size());
    pir::Value x = op.operand_source(0);
    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto input_shape = x_type.dims();
    int nbDims = input_shape.size();
    if (nbDims < 2) {
      VLOG(3) << "Input must have at least 2 dimensions.";
      return false;
    }
    // paddings must provide a (begin, end) pair for every input dimension.
    if (nbDims * 2 != pad_size) {
      VLOG(3) << "Padding size must be twice the number of input dimensions.";
      return false;
    }
    // All entries except the trailing four (last two dims) must be zero.
    for (int i = 0; i < pad_size - 4; i++) {
      if (paddings[i] != 0) {
        VLOG(3) << "Only the last two dimensions can have non-zero paddings.";
        return false;
      }
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
1152+
10951153
class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
10961154
public:
10971155
using pir::OpRewritePattern<paddle::dialect::SplitOp>::OpRewritePattern;
@@ -2914,6 +2972,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
29142972
ps.Add(std::make_unique<LogicalXorOpPattern>(context));
29152973
ps.Add(std::make_unique<CeluOpPattern>(context));
29162974
ps.Add(std::make_unique<OneHotOpPattern>(context));
2975+
ps.Add(std::make_unique<PadOpPattern>(context));
29172976
ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
29182977
ps.Add(std::make_unique<IndexPutOpPattern>(context));
29192978
ps.Add(std::make_unique<InstanceNormOpPattern>(context));

python/paddle/tensorrt/impls/manipulation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,17 @@ def roll_converter(network, paddle_op, inputs):
961961
return layer.get_output(0)
962962

963963

964+
@converter_registry.register("pd_op.pad", trt_version="8.x")
def pad_converter(network, paddle_op, inputs):
    """Lower pd_op.pad to a TensorRT IPaddingNd layer.

    The marker pass guarantees non-zero paddings appear only in the last
    two dimensions, so just the trailing four entries of the attribute
    are consumed here.
    """
    x = inputs[0]
    pads = paddle_op.attrs()["paddings"]
    n = len(pads)
    # Layout is [d0_begin, d0_end, ..., dN_begin, dN_end]; pick the
    # (begin, end) pairs of the last two dimensions.
    begin_pad = [pads[n - 4], pads[n - 2]]
    end_pad = [pads[n - 3], pads[n - 1]]
    pad_layer = network.add_padding_nd(x, begin_pad, end_pad)
    return pad_layer.get_output(0)
973+
974+
964975
@converter_registry.register("pd_op.pad3d", trt_version="8.x")
965976
def pad3d_converter(network, paddle_op, inputs):
966977
input_tensor, paddings = inputs

test/tensorrt/test_converter_manipulation.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,90 @@ def test_trt_result(self):
743743
self.check_trt_result()
744744

745745

746+
def wrapper_pad_error(x, padding, mode, pad_value):
    # NOTE(review): all four arguments are discarded on purpose — the
    # wrapper always pads a fixed random 5-D tensor so the error-path
    # tests can run regardless of what is fed in. TODO confirm intended.
    fixed_input = paddle.to_tensor(
        np.random.randn(1, 1, 1, 2, 3).astype("float32")
    )
    return paddle.nn.functional.pad(
        x=fixed_input,
        pad=[0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        mode='constant',
        value=0,
    )
753+
754+
755+
class TestPadCaseTRTPattern(TensorRTBaseTest):
    """Positive case: zero pad value, non-zero paddings only on the
    last two dimensions — should be converted and match paddle output."""

    def setUp(self):
        self.python_api = paddle.nn.functional.pad
        sample = np.random.randn(1, 1, 1, 2, 3).astype("float32")
        self.api_args = {
            "x": sample,
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([0], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        # Dynamic-shape profile: only the batch dimension varies.
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()
774+
775+
776+
class TestPadError1TRTPattern(TensorRTBaseTest):
    """Negative case: non-zero pad value (1.0) — TRT padding only
    supports zero, so the marker pass must reject the op."""

    def setUp(self):
        self.python_api = paddle.nn.functional.pad
        sample = np.random.randn(1, 1, 1, 2, 3).astype("float32")
        self.api_args = {
            "x": sample,
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([1], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        # Expect the op NOT to be marked as TRT-runnable.
        self.check_marker(expected_result=False)
792+
793+
794+
class TestPadError2TRTPattern(TensorRTBaseTest):
    """Negative case: non-zero paddings outside the last two dimensions
    — the marker pass must reject the op."""

    def setUp(self):
        self.python_api = wrapper_pad_error
        sample = np.random.randn(1, 1, 1, 2, 3).astype("float32")
        self.api_args = {
            "x": sample,
            "paddings": [1, 1, 1, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([1], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)
810+
811+
812+
class TestPadError3TRTPattern(TensorRTBaseTest):
    """Negative case: 2-D input with a 10-entry paddings list, which
    violates the paddings == 2 * rank requirement."""

    def setUp(self):
        self.python_api = wrapper_pad_error
        # NOTE(review): "x" is 2-D here while the shape profiles below are
        # 5-D, and wrapper_pad_error ignores "x" entirely — presumably only
        # the marker check matters for this case; verify against the runner.
        self.api_args = {
            "x": np.random.randn(1, 1).astype("float32"),
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([0], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)
828+
829+
746830
def wrapper_pad3d(x, paddings, mode, value, data_format):
747831
pad3d = paddle.nn.Pad3D(
748832
padding=[1, 0, 1, 2, 0, 0],

0 commit comments

Comments
 (0)