@@ -50,6 +50,7 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
5050 paddle::drr::SourcePattern pat = ctx->SourcePattern ();
5151
5252 std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
53+ bool data_format = false ;
5354 if (bfloat16_ops_ == " onednn_op.conv2d" ) {
5455 op_attrs.emplace (" strides" , pat.Attr (" strides" ));
5556 op_attrs.emplace (" paddings" , pat.Attr (" paddings" ));
@@ -60,6 +61,7 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
6061 op_attrs.emplace (" is_test" , pat.Attr (" is_test" ));
6162 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
6263 op_attrs.emplace (" force_fp32_output" , pat.Attr (" force_fp32_output" ));
64+ data_format = true ;
6365 } else if (bfloat16_ops_ == " onednn_op.matmul" ) {
6466 op_attrs.emplace (" transpose_x" , pat.Attr (" transpose_x" ));
6567 op_attrs.emplace (" transpose_y" , pat.Attr (" transpose_y" ));
@@ -75,21 +77,20 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
7577 op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
7678 op_attrs.emplace (" ceil_mode" , pat.Attr (" ceil_mode" ));
7779 op_attrs.emplace (" exclusive" , pat.Attr (" exclusive" ));
78- op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
7980 op_attrs.emplace (" pooling_type" , pat.Attr (" pooling_type" ));
8081 op_attrs.emplace (" global_pooling" , pat.Attr (" global_pooling" ));
8182 op_attrs.emplace (" adaptive" , pat.Attr (" adaptive" ));
8283 op_attrs.emplace (" padding_algorithm" , pat.Attr (" padding_algorithm" ));
8384 op_attrs.emplace (" use_quantizer" , pat.Attr (" use_quantizer" ));
8485 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
8586 op_attrs.emplace (" is_test" , pat.Attr (" is_test" ));
86-
87+ data_format = true ;
8788 } else if (bfloat16_ops_ == " onednn_op.prelu" ) {
8889 op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
8990 op_attrs.emplace (" mode" , pat.Attr (" mode" ));
9091 op_attrs.emplace (" is_test" , pat.Attr (" is_test" ));
9192 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
92-
93+ data_format = true ;
9394 } else if (bfloat16_ops_ == " onednn_op.sum" ) {
9495 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
9596 op_attrs.emplace (" keepdim" , pat.Attr (" keepdim" ));
@@ -178,15 +179,16 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
178179 });
179180 paddle::drr::ResultPattern res = pat.ResultPattern ();
180181
181- const auto &quantize_op =
182- res.Op (" onednn_op.quantize" ,
183- {{
184- {" scale" , res.Float32Attr (1 .f )},
185- {" shift" , res.Float32Attr (0 .0f )},
186- {" bfloat16" , res.BoolAttr (true )},
187- {" is_negative_input" , res.BoolAttr (false )},
188- {" output_format" , res.StrAttr (" NCHW" )},
189- }});
182+ const auto &quantize_op = res.Op (
183+ " onednn_op.quantize" ,
184+ {{
185+ {" scale" , res.Float32Attr (1 .f )},
186+ {" shift" , res.Float32Attr (0 .0f )},
187+ {" bfloat16" , res.BoolAttr (true )},
188+ {" is_negative_input" , res.BoolAttr (false )},
189+ {" output_format" ,
190+ data_format ? pat.Attr (" data_format" ) : res.StrAttr (" NCHW" )},
191+ }});
190192 quantize_op ({&res.Tensor (" quantize_" + std::to_string (index_))},
191193 {&res.Tensor (" quantize_out_" + std::to_string (index_))});
192194
@@ -251,7 +253,6 @@ class CpuBfloat16DequantPattern : public paddle::drr::DrrPatternBase {
251253 op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
252254 op_attrs.emplace (" ceil_mode" , pat.Attr (" ceil_mode" ));
253255 op_attrs.emplace (" exclusive" , pat.Attr (" exclusive" ));
254- op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
255256 op_attrs.emplace (" pooling_type" , pat.Attr (" pooling_type" ));
256257 op_attrs.emplace (" global_pooling" , pat.Attr (" global_pooling" ));
257258 op_attrs.emplace (" adaptive" , pat.Attr (" adaptive" ));
@@ -383,6 +384,7 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
383384 paddle::drr::SourcePattern pat = ctx->SourcePattern ();
384385
385386 std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
387+ bool data_format = false ;
386388 if (bfloat16_ops_ == " onednn_op.gelu" ) {
387389 op_attrs.emplace (" approximate" , pat.Attr (" approximate" ));
388390 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
@@ -392,12 +394,13 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
392394 op_attrs.emplace (" axis" , pat.Attr (" axis" ));
393395 op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
394396 op_attrs.emplace (" is_test" , pat.Attr (" is_test" ));
395-
397+ data_format = true ;
396398 } else if (bfloat16_ops_ == " onednn_op.transpose" ||
397399 bfloat16_ops_ == " onednn_op.transpose_" ) {
398400 op_attrs.emplace (" perm" , pat.Attr (" perm" ));
399401 op_attrs.emplace (" data_format" , pat.Attr (" data_format" ));
400402 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
403+ data_format = true ;
401404 } else if (bfloat16_ops_ == " onednn_op.relu" ||
402405 bfloat16_ops_ == " onednn_op.relu_" ) {
403406 op_attrs.emplace (" mkldnn_data_type" , pat.Attr (" mkldnn_data_type" ));
@@ -461,15 +464,16 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
461464 });
462465 paddle::drr::ResultPattern res = pat.ResultPattern ();
463466
464- const auto &quantize_op =
465- res.Op (" onednn_op.quantize" ,
466- {{
467- {" scale" , res.Float32Attr (1 .f )},
468- {" shift" , res.Float32Attr (0 .0f )},
469- {" bfloat16" , res.BoolAttr (true )},
470- {" is_negative_input" , res.BoolAttr (false )},
471- {" output_format" , res.StrAttr (" NCHW" )},
472- }});
467+ const auto &quantize_op = res.Op (
468+ " onednn_op.quantize" ,
469+ {{
470+ {" scale" , res.Float32Attr (1 .f )},
471+ {" shift" , res.Float32Attr (0 .0f )},
472+ {" bfloat16" , res.BoolAttr (true )},
473+ {" is_negative_input" , res.BoolAttr (false )},
474+ {" output_format" ,
475+ data_format ? pat.Attr (" data_format" ) : res.StrAttr (" NCHW" )},
476+ }});
473477 quantize_op ({&res.Tensor (" quantize_0" )}, {&res.Tensor (" quantize_out_0" )});
474478
475479 const auto &res_op = res.Op (bfloat16_ops_, op_attrs);
@@ -812,6 +816,7 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
812816 paddle::drr::SourcePattern pat = ctx->SourcePattern ();
813817
814818 std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
819+ bool data_format = false ;
815820 if (bfloat16_ops_ == " onednn_op.fc" ) {
816821 op_attrs.emplace (" in_num_col_dims" , pat.Attr (" in_num_col_dims" ));
817822 op_attrs.emplace (" activation_type" , pat.Attr (" activation_type" ));
@@ -870,6 +875,7 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
870875 op_attrs.emplace (" paddings" , pat.Attr (" paddings" ));
871876 op_attrs.emplace (" strides" , pat.Attr (" strides" ));
872877 op_attrs.emplace (" force_fp32_output" , pat.Attr (" force_fp32_output" ));
878+ data_format = true ;
873879 }
874880
875881 const auto &op = pat.Op (bfloat16_ops_, op_attrs);
@@ -930,15 +936,16 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
930936
931937 paddle::drr::ResultPattern res = pat.ResultPattern ();
932938
933- const auto &quantize_op =
934- res.Op (" onednn_op.quantize" ,
935- {{
936- {" scale" , res.Float32Attr (1 .f )},
937- {" shift" , res.Float32Attr (0 .0f )},
938- {" bfloat16" , res.BoolAttr (true )},
939- {" is_negative_input" , res.BoolAttr (false )},
940- {" output_format" , res.StrAttr (" NCHW" )},
941- }});
939+ const auto &quantize_op = res.Op (
940+ " onednn_op.quantize" ,
941+ {{
942+ {" scale" , res.Float32Attr (1 .f )},
943+ {" shift" , res.Float32Attr (0 .0f )},
944+ {" bfloat16" , res.BoolAttr (true )},
945+ {" is_negative_input" , res.BoolAttr (false )},
946+ {" output_format" ,
947+ data_format ? pat.Attr (" data_format" ) : res.StrAttr (" NCHW" )},
948+ }});
942949 quantize_op ({&res.Tensor (" quantize_" + std::to_string (index_))},
943950 {&res.Tensor (" quantize_out_" + std::to_string (index_))});
944951
@@ -1858,7 +1865,7 @@ class CpuBfloat16PatternFour_one : public paddle::drr::DrrPatternBase {
18581865 {" shift" , res.Float32Attr (0 .0f )},
18591866 {" bfloat16" , res.BoolAttr (true )},
18601867 {" is_negative_input" , res.BoolAttr (false )},
1861- {" output_format" , res. StrAttr ( " NCHW " )},
1868+ {" output_format" , pat. Attr ( " data_format " )},
18621869 }});
18631870 quantize_op ({&res.Tensor (" quantize_" + std::to_string (index_))},
18641871 {&res.Tensor (" quantize_out_" + std::to_string (index_))});
0 commit comments