@@ -151,5 +151,149 @@ def setUp(self):
151151 self .places .append (paddle .CPUPlace ())
152152
153153
154+ class TestConv2dBiasBnOneDNNPassPatternCase2 (PassTest ):
155+ r"""
156+ x_var f_var
157+ \ /
158+ conv2d add_y
159+ \ /
160+ add
161+ |
162+ BatchNorm
163+ |
164+ """
165+
166+ def is_program_valid (self , program = None ):
167+ return True
168+
169+ def build_ir_program (self ):
170+ with paddle .pir_utils .IrGuard ():
171+ main_prog = paddle .static .Program ()
172+ start_prog = paddle .static .Program ()
173+ with paddle .pir .core .program_guard (main_prog , start_prog ):
174+ x = paddle .static .data (
175+ name = 'x' , shape = [3 , 1 , 28 , 28 ], dtype = 'float32'
176+ )
177+ bias_attr = paddle .ParamAttr (
178+ learning_rate = 0.0 ,
179+ initializer = paddle .nn .initializer .Normal (mean = 0.0 , std = 2.0 ),
180+ )
181+ y = paddle .static .create_parameter (
182+ shape = [1 , 32 , 1 , 1 ],
183+ dtype = 'float32' ,
184+ attr = bias_attr ,
185+ is_bias = False ,
186+ )
187+ conv2d = paddle .nn .Conv2D (
188+ in_channels = 1 ,
189+ out_channels = 32 ,
190+ kernel_size = 3 ,
191+ padding = 1 ,
192+ data_format = 'NCHW' ,
193+ bias_attr = False ,
194+ )
195+ bn = paddle .nn .BatchNorm2D (
196+ num_features = 32 ,
197+ data_format = 'NCHW' ,
198+ use_global_stats = True ,
199+ )
200+ add_out = paddle .add (conv2d (x ), y )
201+ out = bn (add_out )
202+ out = paddle .assign (out )
203+ self .pass_attr_list = [{'conv2d_bias_bn_onednn_fuse_pass' : {}}]
204+ self .feeds = {
205+ "x" : np .random .random ((3 , 1 , 28 , 28 )).astype ("float32" ),
206+ "y" : np .random .random ((1 , 32 , 1 , 1 )).astype ("float32" ),
207+ }
208+ self .fetch_list = [out ]
209+ self .valid_op_map = {
210+ "onednn_op.fused_conv2d" : 1 ,
211+ "pd_op.batch_norm_" : 0 ,
212+ }
213+ return [main_prog , start_prog ]
214+
215+ def sample_program (self ):
216+ pir_program = self .build_ir_program ()
217+ yield pir_program , False
218+
219+ def test_check_output (self ):
220+ self .check_pass_correct ()
221+
222+ def setUp (self ):
223+ self .places .append (paddle .CPUPlace ())
224+
225+
226+ class TestConv2dBiasBnOneDNNPassPatternCase3 (PassTest ):
227+ r"""
228+ x_var f_var
229+ \ /
230+ conv2d add_y
231+ \ /
232+ add
233+ |
234+ BatchNorm
235+ |
236+ """
237+
238+ def is_program_valid (self , program = None ):
239+ return True
240+
241+ def build_ir_program (self ):
242+ with paddle .pir_utils .IrGuard ():
243+ main_prog = paddle .static .Program ()
244+ start_prog = paddle .static .Program ()
245+ with paddle .pir .core .program_guard (main_prog , start_prog ):
246+ x = paddle .static .data (
247+ name = 'x' , shape = [3 , 28 , 28 , 1 ], dtype = 'float32'
248+ )
249+ bias_attr = paddle .ParamAttr (
250+ learning_rate = 0.0 ,
251+ initializer = paddle .nn .initializer .Normal (mean = 0.0 , std = 2.0 ),
252+ )
253+ y = paddle .static .create_parameter (
254+ shape = [1 , 1 , 1 , 32 ],
255+ dtype = 'float32' ,
256+ attr = bias_attr ,
257+ is_bias = False ,
258+ )
259+ conv2d = paddle .nn .Conv2D (
260+ in_channels = 1 ,
261+ out_channels = 32 ,
262+ kernel_size = 3 ,
263+ padding = 1 ,
264+ data_format = 'NHWC' ,
265+ bias_attr = False ,
266+ )
267+ bn = paddle .nn .BatchNorm2D (
268+ num_features = 32 ,
269+ data_format = 'NHWC' ,
270+ use_global_stats = True ,
271+ )
272+ add_out = paddle .add (conv2d (x ), y )
273+ out = bn (add_out )
274+ out = paddle .assign (out )
275+ self .pass_attr_list = [{'conv2d_bias_bn_onednn_fuse_pass' : {}}]
276+ self .feeds = {
277+ "x" : np .random .random ((3 , 28 , 28 , 1 )).astype ("float32" ),
278+ "y" : np .random .random ((1 , 1 , 1 , 32 )).astype ("float32" ),
279+ }
280+ self .fetch_list = [out ]
281+ self .valid_op_map = {
282+ "onednn_op.fused_conv2d" : 1 ,
283+ "pd_op.batch_norm_" : 0 ,
284+ }
285+ return [main_prog , start_prog ]
286+
287+ def sample_program (self ):
288+ pir_program = self .build_ir_program ()
289+ yield pir_program , False
290+
291+ def test_check_output (self ):
292+ self .check_pass_correct ()
293+
294+ def setUp (self ):
295+ self .places .append (paddle .CPUPlace ())
296+
297+
154298if __name__ == "__main__" :
155299 unittest .main ()
0 commit comments