
Commit 280d84a

extend capability of conv_bias_bn fusion (#66568)
1 parent: 6a21709

File tree

2 files changed: +174, −4 lines


paddle/fluid/pir/transforms/onednn/conv2d_bn_onednn_fuse_pass.cc

Lines changed: 30 additions & 4 deletions
@@ -192,8 +192,33 @@ class Conv2dBiasBnOneDNNFusePattern
     pir::Value bn_bias = op.bias();
 
     auto add_y_shape = pir::GetShapeFromValue(add_y);
-    // bias currently only support per_tensor add
-    if (add_y_shape.size() != 1 || add_y_shape[0] != 1) return false;
+    auto bn_bias_shape = pir::GetShapeFromValue(bn_bias);
+    std::vector<int64_t> add_y_new_shape{add_y_shape[0]};
+    // Support both per_tensor & per_channel addition
+    if (add_y_shape.size() != 1) {
+      size_t idx;
+      if (data_format == "NHWC") {
+        idx = add_y_shape.size() - 1;
+      } else {
+        idx = 1;
+      }
+      if (add_y_shape[idx] != bn_bias_shape[0]) return false;
+      bool is_ok = true;
+      for (size_t i = 0; i < add_y_shape.size(); i++) {
+        if (i == idx) continue;
+        if (add_y_shape[i] != 1) {
+          is_ok = false;
+          break;
+        }
+      }
+      if (!is_ok) return false;
+      // reshape add_y from [1, X, 1, 1] (NCHW) to [X]
+      add_y_new_shape[0] = add_y_shape[idx];
+    } else if (add_y_shape[0] != 1) {
+      return false;
+    }
+    paddle::dialect::ReshapeOp reshape_add_y_op =
+        rewriter.Build<paddle::dialect::ReshapeOp>(add_op.y(), add_y_new_shape);
 
     // --- deal with filter ---
     auto bn_variance_shape = pir::GetShapeFromValue(bn_variance);
@@ -240,9 +265,10 @@ class Conv2dBiasBnOneDNNFusePattern
         conv2d_filter, reshape_scale_op.out());
 
     // --- deal with bias ---
-    // (add_op.y() - bn_mean)*scale + bn_bias
+    // (add_op.y()(reshaped) - bn_mean)*scale + bn_bias
     paddle::dialect::SubtractOp sub_op_1 =
-        rewriter.Build<paddle::dialect::SubtractOp>(add_op.y(), bn_mean);
+        rewriter.Build<paddle::dialect::SubtractOp>(reshape_add_y_op.out(),
+                                                    bn_mean);
     paddle::dialect::MultiplyOp mul_bias_op =
         rewriter.Build<paddle::dialect::MultiplyOp>(sub_op_1.out(),
                                                     div_op.out());
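
For orientation, here is a minimal standalone Python sketch of the shape check the new code performs before folding the addition (names are illustrative, not taken from the pass): a per-tensor add of shape [1] is still accepted, and a per-channel add is accepted only when its channel axis (axis 1 for NCHW, the last axis for NHWC) matches the BatchNorm bias length and every other axis is 1, in which case it is flattened to [C].

def add_y_qualifies(add_y_shape, bn_bias_shape, data_format="NCHW"):
    """Return the flattened shape [C] if add_y can be folded, else None."""
    if len(add_y_shape) == 1:
        # per-tensor add: only a single element is accepted
        return add_y_shape if add_y_shape[0] == 1 else None
    # per-channel add: the channel axis must match the BN bias length
    idx = len(add_y_shape) - 1 if data_format == "NHWC" else 1
    if add_y_shape[idx] != bn_bias_shape[0]:
        return None
    # every other axis must be 1, e.g. [1, 32, 1, 1] for NCHW
    if any(d != 1 for i, d in enumerate(add_y_shape) if i != idx):
        return None
    return [add_y_shape[idx]]

print(add_y_qualifies([1, 32, 1, 1], [32], "NCHW"))  # [32] -> fusable
print(add_y_qualifies([1, 1, 1, 32], [32], "NHWC"))  # [32] -> fusable
print(add_y_qualifies([3, 32, 1, 1], [32], "NCHW"))  # None -> rejected

In the pass itself, a qualifying tensor is then reshaped to [C] with a ReshapeOp so it can be subtracted from bn_mean directly when the fused bias is built.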

test/ir/pir/fused_pass/onednn/test_conv2d_bn_onednn_fuse_pass.py

Lines changed: 144 additions & 0 deletions
@@ -151,5 +151,149 @@ def setUp(self):
         self.places.append(paddle.CPUPlace())
 
 
+class TestConv2dBiasBnOneDNNPassPatternCase2(PassTest):
+    r"""
+    x_var   f_var
+      \     /
+      conv2d   add_y
+         \     /
+           add
+            |
+        BatchNorm
+            |
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[3, 1, 28, 28], dtype='float32'
+                )
+                bias_attr = paddle.ParamAttr(
+                    learning_rate=0.0,
+                    initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
+                )
+                y = paddle.static.create_parameter(
+                    shape=[1, 32, 1, 1],
+                    dtype='float32',
+                    attr=bias_attr,
+                    is_bias=False,
+                )
+                conv2d = paddle.nn.Conv2D(
+                    in_channels=1,
+                    out_channels=32,
+                    kernel_size=3,
+                    padding=1,
+                    data_format='NCHW',
+                    bias_attr=False,
+                )
+                bn = paddle.nn.BatchNorm2D(
+                    num_features=32,
+                    data_format='NCHW',
+                    use_global_stats=True,
+                )
+                add_out = paddle.add(conv2d(x), y)
+                out = bn(add_out)
+                out = paddle.assign(out)
+                self.pass_attr_list = [{'conv2d_bias_bn_onednn_fuse_pass': {}}]
+                self.feeds = {
+                    "x": np.random.random((3, 1, 28, 28)).astype("float32"),
+                    "y": np.random.random((1, 32, 1, 1)).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.fused_conv2d": 1,
+                    "pd_op.batch_norm_": 0,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        pir_program = self.build_ir_program()
+        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+
+class TestConv2dBiasBnOneDNNPassPatternCase3(PassTest):
+    r"""
+    x_var   f_var
+      \     /
+      conv2d   add_y
+         \     /
+           add
+            |
+        BatchNorm
+            |
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[3, 28, 28, 1], dtype='float32'
+                )
+                bias_attr = paddle.ParamAttr(
+                    learning_rate=0.0,
+                    initializer=paddle.nn.initializer.Normal(mean=0.0, std=2.0),
+                )
+                y = paddle.static.create_parameter(
+                    shape=[1, 1, 1, 32],
+                    dtype='float32',
+                    attr=bias_attr,
+                    is_bias=False,
+                )
+                conv2d = paddle.nn.Conv2D(
+                    in_channels=1,
+                    out_channels=32,
+                    kernel_size=3,
+                    padding=1,
+                    data_format='NHWC',
+                    bias_attr=False,
+                )
+                bn = paddle.nn.BatchNorm2D(
+                    num_features=32,
+                    data_format='NHWC',
+                    use_global_stats=True,
+                )
+                add_out = paddle.add(conv2d(x), y)
+                out = bn(add_out)
+                out = paddle.assign(out)
+                self.pass_attr_list = [{'conv2d_bias_bn_onednn_fuse_pass': {}}]
+                self.feeds = {
+                    "x": np.random.random((3, 28, 28, 1)).astype("float32"),
+                    "y": np.random.random((1, 1, 1, 32)).astype("float32"),
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {
+                    "onednn_op.fused_conv2d": 1,
+                    "pd_op.batch_norm_": 0,
+                }
+                return [main_prog, start_prog]
+
+    def sample_program(self):
+        pir_program = self.build_ir_program()
+        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+
+
 if __name__ == "__main__":
     unittest.main()
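
The tests above rest on the standard frozen-BatchNorm folding identity. A small NumPy sketch of why folding the per-channel add into the fused bias preserves the output (shapes, epsilon, and variable names are illustrative, not taken from the test):

import numpy as np

# Sketch: for a frozen BatchNorm, BN(conv(x) + y) == conv_scaled(x) + fused_bias,
# where the filter is scaled per output channel and the per-channel add y is
# folded into the bias, as the pass does.
rng = np.random.default_rng(0)
C, eps = 4, 1e-5
conv_out = rng.standard_normal((2, C, 8, 8)).astype("float32")  # stand-in for conv2d output (NCHW)
y = rng.standard_normal((1, C, 1, 1)).astype("float32")         # per-channel add, as in Case2
gamma, beta = rng.standard_normal(C), rng.standard_normal(C)
mean, var = rng.standard_normal(C), rng.random(C) + 0.5

scale = gamma / np.sqrt(var + eps)

# Reference: add, then batch-norm with global (frozen) statistics.
bn = (conv_out + y - mean.reshape(1, C, 1, 1)) * scale.reshape(1, C, 1, 1) \
     + beta.reshape(1, C, 1, 1)

# Fused form: scale the conv output per channel and add
# fused_bias = (y.reshape(C) - mean) * scale + beta, as built by the pass.
fused_bias = (y.reshape(C) - mean) * scale + beta
fused = conv_out * scale.reshape(1, C, 1, 1) + fused_bias.reshape(1, C, 1, 1)

print(np.allclose(bn, fused, atol=1e-5))  # True

Scaling the convolution output per channel here stands in for scaling the filter (reshape_scale_op in the pass); the two are equivalent because the convolution is linear in its filter.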
