178 changes: 178 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1,3 +1,4 @@

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -66,6 +67,8 @@ DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
DEFINE_GENERAL_PATTERN(Conv2d, paddle::dialect::Conv2dOp)
DEFINE_GENERAL_PATTERN(FusedConv2dAddAct, paddle::dialect::FusedConv2dAddActOp)
DEFINE_GENERAL_PATTERN(DepthwiseConv2d, paddle::dialect::DepthwiseConv2dOp)
DEFINE_GENERAL_PATTERN(Shape, paddle::dialect::ShapeOp)
DEFINE_GENERAL_PATTERN(Expand, paddle::dialect::ExpandOp)
DEFINE_GENERAL_PATTERN(Sigmoid, paddle::dialect::SigmoidOp)

#undef DEFINE_GENERAL_PATTERN
@@ -919,6 +922,172 @@ class MultiplyOpPattern
}
};

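// Marks pd_op.subtract as runnable by TensorRT; bool operands are rejected
// because TRT elementwise layers do not support them.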
class SubtractOpPattern
: public pir::OpRewritePattern<paddle::dialect::SubtractOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SubtractOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::SubtractOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_sub do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

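// Marks pd_op.divide as runnable by TensorRT unless either operand is bool.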
class DivideOpPattern
: public pir::OpRewritePattern<paddle::dialect::DivideOp> {
public:
using pir::OpRewritePattern<paddle::dialect::DivideOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::DivideOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_div do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

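// Marks pd_op.elementwise_pow as runnable by TensorRT; bool and int32
// operands are rejected.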
class ElementwisePowOpPattern
: public pir::OpRewritePattern<paddle::dialect::ElementwisePowOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::ElementwisePowOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::ElementwisePowOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || x_dtype.isa<pir::Int32Type>() ||
y_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::Int32Type>()) {
VLOG(3) << "elementwise_pow do not support"
"boolean datatype and int32 datatype.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
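
// Marks pd_op.minimum as runnable by TensorRT unless either operand is bool.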
class MinimumOpPattern
: public pir::OpRewritePattern<paddle::dialect::MinimumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::MinimumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::MinimumOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_min do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
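
// Marks pd_op.maximum as runnable by TensorRT unless either operand is bool.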
class MaximumOpPattern
: public pir::OpRewritePattern<paddle::dialect::MaximumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::MaximumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::MaximumOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_max do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

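// Marks pd_op.floor_divide as runnable by TensorRT unless either operand is
// bool.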
class FloorDivideOpPattern
: public pir::OpRewritePattern<paddle::dialect::FloorDivideOp> {
public:
using pir::OpRewritePattern<paddle::dialect::FloorDivideOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::FloorDivideOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_floordiv do not support boolean datatype.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

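// Marks pd_op.remainder as runnable by TensorRT unless either operand is
// bool.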
class RemainderOpPattern
: public pir::OpRewritePattern<paddle::dialect::RemainderOp> {
public:
using pir::OpRewritePattern<paddle::dialect::RemainderOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::RemainderOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "elementwise_mod do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
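
// Pattern-rewrite pass that tags TensorRT-convertible ops with kCanRunTrtAttr.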
class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -948,6 +1117,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(DepthwiseConv2d)
ADD_PATTERN(Nonzero)
ADD_PATTERN(Gelu)
ADD_PATTERN(Shape)
ADD_PATTERN(Expand)
ADD_PATTERN(Sigmoid)

#undef ADD_PATTERN
@@ -974,6 +1145,13 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
ps.Add(std::make_unique<MultiplyOpPattern>(context));
ps.Add(std::make_unique<SubtractOpPattern>(context));
ps.Add(std::make_unique<DivideOpPattern>(context));
ps.Add(std::make_unique<ElementwisePowOpPattern>(context));
ps.Add(std::make_unique<MinimumOpPattern>(context));
ps.Add(std::make_unique<MaximumOpPattern>(context));
ps.Add(std::make_unique<FloorDivideOpPattern>(context));
ps.Add(std::make_unique<RemainderOpPattern>(context));
return ps;
}
};
58 changes: 58 additions & 0 deletions test/tensorrt/test_trt_marker_divide.py
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


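# Verifies that trt_op_marker_pass tags pd_op.divide as TRT-runnable.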
class TestDivideTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[3], dtype='float32')
y = paddle.static.data(name='y', shape=[3], dtype='float32')
divide_out = paddle.divide(x, y)
out = paddle.assign(divide_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([2, 3, 4]).astype("float32"),
"y": np.array([1, 5, 2]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.divide"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()
58 changes: 58 additions & 0 deletions test/tensorrt/test_trt_marker_elementwise_pow.py
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


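# Verifies that trt_op_marker_pass tags pd_op.elementwise_pow as TRT-runnable.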
class TestElementWisePowTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[3], dtype='float32')
y = paddle.static.data(name='y', shape=[1], dtype='float32')
pow_out = paddle.pow(x, y)
out = paddle.assign(pow_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([1, 2, 3]).astype("float32"),
"y": np.array([2]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.elementwise_pow"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()
56 changes: 56 additions & 0 deletions test/tensorrt/test_trt_marker_expand.py
@@ -0,0 +1,56 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


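# Verifies that trt_op_marker_pass tags pd_op.expand as TRT-runnable.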
class TestExpandTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name="x", shape=[3], dtype="float32")
expand_out = paddle.expand(x, shape=[2, 3])
out = paddle.assign(expand_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([[1, 2, 3]]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.expand"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()