Merged
32 commits
fb01806
add_pad_op
PolaKuma Nov 25, 2024
cce2a0d
fix codestyle
PolaKuma Nov 25, 2024
73e6907
fix codestyle
PolaKuma Nov 25, 2024
95071d5
Merge branch 'PaddlePaddle:develop' into add_pd_op
PolaKuma Nov 26, 2024
95d254c
update
PolaKuma Nov 28, 2024
9a82b0e
Merge branch 'develop' into add_pd_op
PolaKuma Dec 11, 2024
9a5b9d9
Update trt_op_marker_pass.cc
PolaKuma Dec 30, 2024
2dd2e05
update
PolaKuma Dec 30, 2024
99fc17e
update
PolaKuma Jan 7, 2025
497f7e9
Merge branch 'develop' into add_pd_op
PolaKuma Jan 7, 2025
91f0bf9
Update trt_op_marker_pass.cc
PolaKuma Jan 7, 2025
3471139
Merge branch 'add_pd_op' of https://github.com/PolaKuma/Paddle into a…
PolaKuma Jan 7, 2025
2d8cfcf
fix codestyle
PolaKuma Jan 7, 2025
a047d4a
fix codestyle
PolaKuma Jan 7, 2025
8d1b8b7
Merge branch 'develop' into add_pd_op
PolaKuma Jan 7, 2025
eab4cbd
fix opt
PolaKuma Jan 8, 2025
e22ce64
Update manipulation.py
PolaKuma Jan 8, 2025
acfc70b
Update trt_op_marker_pass.cc
PolaKuma Jan 10, 2025
8cf03ae
Update trt_op_marker_pass.cc
PolaKuma Jan 10, 2025
1f76a96
update
PolaKuma Jan 10, 2025
d5de8a2
Merge branch 'develop' into add_pd_op
PolaKuma Jan 14, 2025
44c6555
Update test_converter_manipulation.py
PolaKuma Jan 21, 2025
1ba7381
Update test_converter_manipulation.py
PolaKuma Jan 21, 2025
a838c35
Merge branch 'PaddlePaddle:develop' into add_pd_op
PolaKuma Feb 10, 2025
b2d8e72
fix
PolaKuma Feb 18, 2025
0247e70
Merge branch 'develop' into add_pd_op
PolaKuma Feb 19, 2025
54606b5
Update trt_op_marker_pass.cc
PolaKuma Feb 21, 2025
a184bcc
fix codestyle
PolaKuma Feb 21, 2025
8906875
Merge branch 'PaddlePaddle:develop' into add_pd_op
PolaKuma Feb 21, 2025
92daaed
fix codestyle
PolaKuma Feb 21, 2025
5a9aa6f
Update trt_op_marker_pass.cc
PolaKuma Feb 21, 2025
a183a41
Merge branch 'PaddlePaddle:develop' into add_pd_op
PolaKuma Feb 22, 2025
59 changes: 59 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1065,6 +1065,7 @@ class FlattenOpPattern
    return true;
  }
};

class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::CastOp>::OpRewritePattern;
@@ -1092,6 +1093,63 @@ class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
  }
};

class PadOpPattern : public pir::OpRewritePattern<paddle::dialect::PadOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::PadOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::PadOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    pir::Value pad_value_tensor = op.operand_source(1);
    if (!op->HasAttribute("paddings") || !pad_value_tensor) {
      VLOG(3) << "PadOp must have 'paddings' and 'pad_value'.";
      return false;
    }
    if (pir::GetDefiningOpForInput(op, 1)->isa<paddle::dialect::FullOp>()) {
      paddle::dialect::FullOp full_op =
          pir::GetDefiningOpForInput(op, 1)
              ->dyn_cast<paddle::dialect::FullOp>();
      auto pad_value =
          full_op->attribute<paddle::dialect::ScalarAttribute>("value")
              .data()
              .to<float>();
      if (pad_value != 0.0f) {
        VLOG(3) << "The pad layer of TRT only supports a zero pad value.";
        return false;
      }
    }
    auto paddings_attr = op->attribute<pir::ArrayAttribute>("paddings");
    std::vector<int> paddings;
    for (const auto &attr : paddings_attr.AsVector()) {
      paddings.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
    }
    int pad_size = paddings.size();
    pir::Value x = op.operand_source(0);
    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto input_shape = x_type.dims();
    int nbDims = input_shape.size();
    if (nbDims < 2) {
      VLOG(3) << "Input must have at least 2 dimensions.";
      return false;
    }
    if (nbDims * 2 != pad_size) {
      VLOG(3) << "Padding size must be twice the number of input dimensions.";
      return false;
    }
    for (int i = 0; i < pad_size - 4; i++) {
      if (paddings[i] != 0) {
        VLOG(3) << "Only the last two dimensions can have non-zero paddings.";
        return false;
      }
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::SplitOp>::OpRewritePattern;
@@ -2914,6 +2972,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
    ps.Add(std::make_unique<LogicalXorOpPattern>(context));
    ps.Add(std::make_unique<CeluOpPattern>(context));
    ps.Add(std::make_unique<OneHotOpPattern>(context));
    ps.Add(std::make_unique<PadOpPattern>(context));
    ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
    ps.Add(std::make_unique<IndexPutOpPattern>(context));
    ps.Add(std::make_unique<InstanceNormOpPattern>(context));
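
In short, PadOpPattern only marks a pd_op.pad as TRT-convertible when the input is at least 2-D, the paddings attribute carries one [before, after] pair per dimension, only the last two dimensions are padded, and any constant pad value is zero. A minimal Python sketch of those checks (illustrative only; the helper name is hypothetical):

def can_mark_pad_for_trt(input_ndim, paddings, pad_value):
    # Input must have at least 2 dimensions.
    if input_ndim < 2:
        return False
    # paddings carries a [before, after] pair for every input dimension.
    if len(paddings) != 2 * input_ndim:
        return False
    # Only the last two dimensions (the last four entries) may be non-zero.
    if any(p != 0 for p in paddings[:-4]):
        return False
    # TRT's pad layer is only used here for zero-valued constant padding.
    return pad_value == 0.0
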
11 changes: 11 additions & 0 deletions python/paddle/tensorrt/impls/manipulation.py
@@ -961,6 +961,17 @@ def roll_converter(network, paddle_op, inputs):
    return layer.get_output(0)


@converter_registry.register("pd_op.pad", trt_version="8.x")
def pad_converter(network, paddle_op, inputs):
    input_tensor = inputs[0]
    paddings = paddle_op.attrs()["paddings"]
    pad_size = len(paddings)
    # The marker pass guarantees that only the last two dimensions carry
    # non-zero padding; slice their [before, after] pairs out of the flat list.
    pre_pad = [paddings[pad_size - 4], paddings[pad_size - 2]]
    post_pad = [paddings[pad_size - 3], paddings[pad_size - 1]]
    layer = network.add_padding_nd(input_tensor, pre_pad, post_pad)
    return layer.get_output(0)


@converter_registry.register("pd_op.pad3d", trt_version="8.x")
def pad3d_converter(network, paddle_op, inputs):
    input_tensor, paddings = inputs
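
As a worked example of the slicing in pad_converter (values are illustrative): for a 4-D NCHW input, the flat paddings list interleaves one [before, after] pair per dimension, and the converter extracts the height/width pairs that add_padding_nd expects.

# Hypothetical paddings for a 4-D input, ordered
# [N_before, N_after, C_before, C_after, H_before, H_after, W_before, W_after].
paddings = [0, 0, 0, 0, 1, 2, 3, 4]
pad_size = len(paddings)
pre_pad = [paddings[pad_size - 4], paddings[pad_size - 2]]  # [1, 3]: H/W before
post_pad = [paddings[pad_size - 3], paddings[pad_size - 1]]  # [2, 4]: H/W after
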
84 changes: 84 additions & 0 deletions test/tensorrt/test_converter_manipulation.py
@@ -743,6 +743,90 @@ def test_trt_result(self):
        self.check_trt_result()


def wrapper_pad_error(x, padding, mode, pad_value):
    # Note: this wrapper deliberately ignores its arguments and pads a fixed
    # constant tensor instead.
    return paddle.nn.functional.pad(
        x=paddle.to_tensor(np.random.randn(1, 1, 1, 2, 3).astype("float32")),
        pad=[0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        mode='constant',
        value=0,
    )


class TestPadCaseTRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.pad
        self.api_args = {
            "x": np.random.randn(1, 1, 1, 2, 3).astype("float32"),
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([0], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result_fp16(self):
        self.check_trt_result(precision_mode="fp16")

    def test_trt_result_fp32(self):
        self.check_trt_result()


class TestPadError1TRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.pad
        self.api_args = {
            "x": np.random.randn(1, 1, 1, 2, 3).astype("float32"),
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([1], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)


class TestPadError2TRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = wrapper_pad_error
        self.api_args = {
            "x": np.random.randn(1, 1, 1, 2, 3).astype("float32"),
            "paddings": [1, 1, 1, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([1], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)


class TestPadError3TRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = wrapper_pad_error
        self.api_args = {
            "x": np.random.randn(1, 1).astype("float32"),
            "paddings": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            "mode": "constant",
            "pad_value": np.array([0], dtype="float32"),
        }
        self.program_config = {"feed_list": ["x", "pad_value"]}
        self.min_shape = {"x": [1, 1, 1, 2, 3]}
        self.opt_shape = {"x": [1, 1, 1, 2, 3]}
        self.max_shape = {"x": [5, 1, 1, 2, 3]}

    def test_trt_result(self):
        self.check_marker(expected_result=False)
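
For reference, a minimal end-to-end call that satisfies all of the marker's constraints (a sketch mirroring the accepted case in TestPadCaseTRTPattern above):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(1, 1, 1, 2, 3).astype("float32"))
# Zero-valued constant padding confined to the last two dimensions,
# which PadOpPattern accepts and marks as runnable in TensorRT.
y = paddle.nn.functional.pad(
    x, pad=[0, 0, 0, 0, 0, 0, 1, 1, 0, 0], mode="constant", value=0
)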


def wrapper_pad3d(x, paddings, mode, value, data_format):
    pad3d = paddle.nn.Pad3D(
        padding=[1, 0, 1, 2, 0, 0],