Skip to content

Commit 1cf0e34

Browse files
LLee233 authored and lixcli committed
[PIR][oneDNN] Extend Conv bias fusion capability (PaddlePaddle#66000)
* extend conv+bias fusion * revise copyright
1 parent c9d4581 commit 1cf0e34

File tree

3 files changed

+116
-14
lines changed

3 files changed

+116
-14
lines changed

paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ class ConvBiasFusePattern : public paddle::drr::DrrPatternBase {
5858

5959
pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias"));
6060

61-
if (conv_name_ == paddle::dialect::Conv2dOp::name() ||
62-
conv_name_ == paddle::onednn::dialect::FusedConv2dOp::name()) {
61+
if (conv_name_ == paddle::dialect::Conv2dOp::name()) {
6362
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
6463
if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) {
6564
return false;
@@ -96,7 +95,19 @@ class ConvBiasFusePattern : public paddle::drr::DrrPatternBase {
9695
}
9796
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
9897
auto bias_shape = pir::GetShapeFromValue(match_ctx.Tensor("bias"));
99-
if (bias_shape.size() != 1) return false;
98+
auto output_shape = pir::GetShapeFromValue(match_ctx.Tensor("conv_out"));
99+
if (bias_shape.size() != 1) {
100+
if (bias_shape[1] != output_shape[1]) return false;
101+
bool is_ok = true;
102+
for (size_t i = 0; i < bias_shape.size(); i++) {
103+
if (i == 1) continue;
104+
if (bias_shape[i] != 1) {
105+
is_ok = false;
106+
break;
107+
}
108+
}
109+
return is_ok;
110+
}
100111
return true;
101112
});
102113

@@ -304,7 +315,7 @@ class FusedConvTransposeAddFusePattern : public paddle::drr::DrrPatternBase {
304315

305316
class Conv2dBiasFusePass : public pir::PatternRewritePass {
306317
public:
307-
Conv2dBiasFusePass() : pir::PatternRewritePass("conv2d_bias_fuse_pass", 2) {}
318+
Conv2dBiasFusePass() : pir::PatternRewritePass("conv2d_bias_fuse_pass", 3) {}
308319

309320
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
310321
pir::RewritePatternSet ps(context);
@@ -319,7 +330,7 @@ class Conv2dBiasFusePass : public pir::PatternRewritePass {
319330
class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass {
320331
public:
321332
Conv2dTransposeBiasFusePass()
322-
: pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 2) {}
333+
: pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 3) {}
323334

324335
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
325336
pir::RewritePatternSet ps(context);
@@ -331,7 +342,7 @@ class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass {
331342

332343
class Conv3dBiasFusePass : public pir::PatternRewritePass {
333344
public:
334-
Conv3dBiasFusePass() : pir::PatternRewritePass("conv3d_bias_fuse_pass", 2) {}
345+
Conv3dBiasFusePass() : pir::PatternRewritePass("conv3d_bias_fuse_pass", 3) {}
335346

336347
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
337348
pir::RewritePatternSet ps(context);

paddle/phi/kernels/onednn/conv_handler.h

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -125,12 +125,30 @@ class ConvOneDNNHandlerT
125125
DataLayout::ONEDNN,
126126
bias->layout()));
127127

128-
PADDLE_ENFORCE_EQ(
129-
bias->dims().size(),
130-
1,
131-
phi::errors::InvalidArgument("Bias must only have 1 dimension, "
132-
"i.e. X, but got dimension = %d .",
133-
bias->dims().size()));
128+
auto bias_shape = common::vectorize(bias->dims());
129+
auto output_shape = common::vectorize(output->dims());
130+
// layout of bias is always NCHW/NCDHW, so channel is always at 1st dim
131+
if (bias_shape.size() != 1) {
132+
PADDLE_ENFORCE_EQ(
133+
bias_shape[1],
134+
output_shape[1],
135+
phi::errors::InvalidArgument(
136+
"Bias must only have 1 dimension or only bias_dims[1] == "
137+
"output_dims[1] i.e. [X] or [1, X, 1, 1], but got dimension "
138+
"== %d and failed",
139+
bias->dims().size()));
140+
for (size_t i = 0; i < bias_shape.size(); i++) {
141+
if (i == 1) continue;
142+
PADDLE_ENFORCE_EQ(
143+
bias_shape[i],
144+
1,
145+
phi::errors::InvalidArgument(
146+
"Bias with multiply dimensions must only have 1 dimension "
147+
"> 1, i.e. [1, X, 1, 1], but got %d-th dimension == %d .",
148+
i,
149+
bias_shape[i]));
150+
}
151+
}
134152
}
135153
const auto input_dims = input->dims();
136154
const auto data_dims =
@@ -195,6 +213,7 @@ class ConvOneDNNHandlerT
195213

196214
if (bias) {
197215
auto bias_tz = common::vectorize(bias->dims());
216+
if (bias_tz.size() > 1) bias_tz = {bias_tz[1]};
198217
dnnl::memory::desc bias_md =
199218
funcs::OneDNNMemDesc(bias_tz,
200219
dnnl::memory::data_type::f32,
@@ -594,8 +613,16 @@ class ConvOneDNNHandlerT
594613
}
595614
const K_Bias* bias_data = bias->data<K_Bias>();
596615

616+
dnnl::memory::desc bias_md = bias->mem_desc();
617+
auto bias_tz = common::vectorize(bias->dims());
618+
if (bias_tz.size() > 1) {
619+
bias_tz = {bias_tz[1]};
620+
bias_md = funcs::OneDNNMemDesc(bias_tz,
621+
dnnl::memory::data_type::f32,
622+
funcs::OneDNNMemoryFormat::x);
623+
}
597624
return this->AcquireMemoryWithReorder(
598-
bias->mem_desc(),
625+
bias_md,
599626
this->fwd_pd_->bias_desc(),
600627
funcs::to_void_cast<K_Bias>(bias_data),
601628
"@bias_mem_p",

test/ir/pir/fused_pass/onednn/test_conv2d_bias_fuse_pass.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,70 @@ def test_check_output(self):
8383
self.check_pass_correct()
8484

8585

86+
class TestConv2dAddFusePassCase2(PassTest):
87+
def is_program_valid(self, program=None):
88+
return True
89+
90+
def build_ir_program(self):
91+
with paddle.pir_utils.IrGuard():
92+
main_prog = paddle.static.Program()
93+
start_prog = paddle.static.Program()
94+
with paddle.pir.core.program_guard(main_prog, start_prog):
95+
x = paddle.static.data(
96+
name='x', shape=[5, 5, 5, 5], dtype='float32'
97+
)
98+
bias_attr = paddle.ParamAttr(
99+
learning_rate=0.0,
100+
initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
101+
)
102+
bias = paddle.static.create_parameter(
103+
shape=[1, 3, 1, 1],
104+
dtype='float32',
105+
attr=bias_attr,
106+
is_bias=False,
107+
)
108+
w_attr = paddle.ParamAttr(
109+
learning_rate=0.0,
110+
initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
111+
)
112+
conv2d = paddle.nn.Conv2D(
113+
in_channels=5,
114+
out_channels=3,
115+
kernel_size=[1, 1],
116+
groups=1,
117+
stride=[1, 1],
118+
padding=[1, 1, 1, 1],
119+
dilation=[1, 1],
120+
data_format='NCHW',
121+
bias_attr=False,
122+
weight_attr=w_attr,
123+
)
124+
125+
out = paddle.add(conv2d(x), bias)
126+
out = paddle.assign(out)
127+
self.pass_attr_list = [{'conv2d_bias_fuse_pass': {}}]
128+
self.feeds = {
129+
"x": np.random.random((5, 5, 5, 5)).astype("float32"),
130+
"bias": np.random.random((1, 3, 1, 1)).astype("float32"),
131+
}
132+
self.fetch_list = [out]
133+
self.valid_op_map = {
134+
"onednn_op.fused_conv2d": 1,
135+
"pd_op.conv2d": 0,
136+
"pd_op.add": 0,
137+
}
138+
return [main_prog, start_prog]
139+
140+
def sample_program(self):
141+
yield self.build_ir_program(), False
142+
143+
def setUp(self):
144+
self.places.append(paddle.CPUPlace())
145+
146+
def test_check_output(self):
147+
self.check_pass_correct()
148+
149+
86150
class TestConv2dAddFusePassWithAddParam(PassTest):
87151
def is_program_valid(self, program=None):
88152
return True

0 commit comments

Comments (0)