Commit 7679609
Fix bf16 quantize data_format bug (#70099)
1 parent f1868d3

1 file changed: +40 -33 lines
paddle/fluid/pir/transforms/onednn/cpu_bfloat16_pass.cc
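The bug, as the diff below shows: every quantize op this pass inserts ahead of a bf16 oneDNN op had its "output_format" attribute hard-coded to "NCHW", so a wrapped op whose tensors use another layout (e.g. NHWC) received a quantize with the wrong format. The fix threads a data_format flag through each pattern class: branches for source ops that carry a "data_format" attribute set the flag, and the quantize op then forwards the matched op's attribute instead of the literal. A condensed sketch of the idiom (an excerpt assuming the paddle::drr API exactly as used in this file; pat, res, op_attrs, and bfloat16_ops_ belong to the surrounding pattern class and are not re-declared here):

    bool data_format = false;  // true only in branches whose op has the attribute
    if (bfloat16_ops_ == "onednn_op.conv2d") {
      // ... emplace the conv2d attributes into op_attrs ...
      data_format = true;  // conv2d carries a data_format attribute
    }
    // ... further branches, one per supported oneDNN op ...

    const auto &quantize_op = res.Op(
        "onednn_op.quantize",
        {{
            {"scale", res.Float32Attr(1.f)},
            {"shift", res.Float32Attr(0.0f)},
            {"bfloat16", res.BoolAttr(true)},
            {"is_negative_input", res.BoolAttr(false)},
            // Forward the matched op's layout when it has one; otherwise
            // keep the previous "NCHW" default.
            {"output_format",
             data_format ? pat.Attr("data_format") : res.StrAttr("NCHW")},
        }});

Here pat.Attr("data_format") references the attribute captured from the matched source op, while res.StrAttr("NCHW") materializes a constant string attribute in the result pattern, so ops without a layout attribute behave exactly as before.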
@@ -50,6 +50,7 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
     paddle::drr::SourcePattern pat = ctx->SourcePattern();
 
     std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
+    bool data_format = false;
     if (bfloat16_ops_ == "onednn_op.conv2d") {
       op_attrs.emplace("strides", pat.Attr("strides"));
       op_attrs.emplace("paddings", pat.Attr("paddings"));
@@ -60,6 +61,7 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
       op_attrs.emplace("is_test", pat.Attr("is_test"));
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
       op_attrs.emplace("force_fp32_output", pat.Attr("force_fp32_output"));
+      data_format = true;
     } else if (bfloat16_ops_ == "onednn_op.matmul") {
       op_attrs.emplace("transpose_x", pat.Attr("transpose_x"));
       op_attrs.emplace("transpose_y", pat.Attr("transpose_y"));
@@ -75,21 +77,20 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
       op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("ceil_mode", pat.Attr("ceil_mode"));
       op_attrs.emplace("exclusive", pat.Attr("exclusive"));
-      op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("pooling_type", pat.Attr("pooling_type"));
       op_attrs.emplace("global_pooling", pat.Attr("global_pooling"));
       op_attrs.emplace("adaptive", pat.Attr("adaptive"));
       op_attrs.emplace("padding_algorithm", pat.Attr("padding_algorithm"));
       op_attrs.emplace("use_quantizer", pat.Attr("use_quantizer"));
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
       op_attrs.emplace("is_test", pat.Attr("is_test"));
-
+      data_format = true;
     } else if (bfloat16_ops_ == "onednn_op.prelu") {
       op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("mode", pat.Attr("mode"));
       op_attrs.emplace("is_test", pat.Attr("is_test"));
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
-
+      data_format = true;
     } else if (bfloat16_ops_ == "onednn_op.sum") {
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
       op_attrs.emplace("keepdim", pat.Attr("keepdim"));
@@ -178,15 +179,16 @@ class CpuBfloat16Pattern : public paddle::drr::DrrPatternBase {
     });
     paddle::drr::ResultPattern res = pat.ResultPattern();
 
-    const auto &quantize_op =
-        res.Op("onednn_op.quantize",
-               {{
-                   {"scale", res.Float32Attr(1.f)},
-                   {"shift", res.Float32Attr(0.0f)},
-                   {"bfloat16", res.BoolAttr(true)},
-                   {"is_negative_input", res.BoolAttr(false)},
-                   {"output_format", res.StrAttr("NCHW")},
-               }});
+    const auto &quantize_op = res.Op(
+        "onednn_op.quantize",
+        {{
+            {"scale", res.Float32Attr(1.f)},
+            {"shift", res.Float32Attr(0.0f)},
+            {"bfloat16", res.BoolAttr(true)},
+            {"is_negative_input", res.BoolAttr(false)},
+            {"output_format",
+             data_format ? pat.Attr("data_format") : res.StrAttr("NCHW")},
+        }});
     quantize_op({&res.Tensor("quantize_" + std::to_string(index_))},
                 {&res.Tensor("quantize_out_" + std::to_string(index_))});
 
@@ -251,7 +253,6 @@ class CpuBfloat16DequantPattern : public paddle::drr::DrrPatternBase {
       op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("ceil_mode", pat.Attr("ceil_mode"));
       op_attrs.emplace("exclusive", pat.Attr("exclusive"));
-      op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("pooling_type", pat.Attr("pooling_type"));
       op_attrs.emplace("global_pooling", pat.Attr("global_pooling"));
       op_attrs.emplace("adaptive", pat.Attr("adaptive"));
@@ -383,6 +384,7 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
     paddle::drr::SourcePattern pat = ctx->SourcePattern();
 
     std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
+    bool data_format = false;
     if (bfloat16_ops_ == "onednn_op.gelu") {
       op_attrs.emplace("approximate", pat.Attr("approximate"));
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
@@ -392,12 +394,13 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
       op_attrs.emplace("axis", pat.Attr("axis"));
       op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("is_test", pat.Attr("is_test"));
-
+      data_format = true;
     } else if (bfloat16_ops_ == "onednn_op.transpose" ||
                bfloat16_ops_ == "onednn_op.transpose_") {
       op_attrs.emplace("perm", pat.Attr("perm"));
       op_attrs.emplace("data_format", pat.Attr("data_format"));
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
+      data_format = true;
     } else if (bfloat16_ops_ == "onednn_op.relu" ||
                bfloat16_ops_ == "onednn_op.relu_") {
       op_attrs.emplace("mkldnn_data_type", pat.Attr("mkldnn_data_type"));
@@ -461,15 +464,16 @@ class CpuBfloat16PatternOne_one : public paddle::drr::DrrPatternBase {
     });
     paddle::drr::ResultPattern res = pat.ResultPattern();
 
-    const auto &quantize_op =
-        res.Op("onednn_op.quantize",
-               {{
-                   {"scale", res.Float32Attr(1.f)},
-                   {"shift", res.Float32Attr(0.0f)},
-                   {"bfloat16", res.BoolAttr(true)},
-                   {"is_negative_input", res.BoolAttr(false)},
-                   {"output_format", res.StrAttr("NCHW")},
-               }});
+    const auto &quantize_op = res.Op(
+        "onednn_op.quantize",
+        {{
+            {"scale", res.Float32Attr(1.f)},
+            {"shift", res.Float32Attr(0.0f)},
+            {"bfloat16", res.BoolAttr(true)},
+            {"is_negative_input", res.BoolAttr(false)},
+            {"output_format",
+             data_format ? pat.Attr("data_format") : res.StrAttr("NCHW")},
+        }});
     quantize_op({&res.Tensor("quantize_0")}, {&res.Tensor("quantize_out_0")});
 
     const auto &res_op = res.Op(bfloat16_ops_, op_attrs);
@@ -812,6 +816,7 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
     paddle::drr::SourcePattern pat = ctx->SourcePattern();
 
     std::unordered_map<std::string, paddle::drr::Attribute> op_attrs;
+    bool data_format = false;
     if (bfloat16_ops_ == "onednn_op.fc") {
       op_attrs.emplace("in_num_col_dims", pat.Attr("in_num_col_dims"));
       op_attrs.emplace("activation_type", pat.Attr("activation_type"));
@@ -870,6 +875,7 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
       op_attrs.emplace("paddings", pat.Attr("paddings"));
       op_attrs.emplace("strides", pat.Attr("strides"));
       op_attrs.emplace("force_fp32_output", pat.Attr("force_fp32_output"));
+      data_format = true;
     }
 
     const auto &op = pat.Op(bfloat16_ops_, op_attrs);
@@ -930,15 +936,16 @@ class CpuBfloat16PatternThree_one : public paddle::drr::DrrPatternBase {
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
 
-    const auto &quantize_op =
-        res.Op("onednn_op.quantize",
-               {{
-                   {"scale", res.Float32Attr(1.f)},
-                   {"shift", res.Float32Attr(0.0f)},
-                   {"bfloat16", res.BoolAttr(true)},
-                   {"is_negative_input", res.BoolAttr(false)},
-                   {"output_format", res.StrAttr("NCHW")},
-               }});
+    const auto &quantize_op = res.Op(
+        "onednn_op.quantize",
+        {{
+            {"scale", res.Float32Attr(1.f)},
+            {"shift", res.Float32Attr(0.0f)},
+            {"bfloat16", res.BoolAttr(true)},
+            {"is_negative_input", res.BoolAttr(false)},
+            {"output_format",
+             data_format ? pat.Attr("data_format") : res.StrAttr("NCHW")},
+        }});
     quantize_op({&res.Tensor("quantize_" + std::to_string(index_))},
                 {&res.Tensor("quantize_out_" + std::to_string(index_))});
 
@@ -1858,7 +1865,7 @@ class CpuBfloat16PatternFour_one : public paddle::drr::DrrPatternBase {
                    {"shift", res.Float32Attr(0.0f)},
                    {"bfloat16", res.BoolAttr(true)},
                    {"is_negative_input", res.BoolAttr(false)},
-                   {"output_format", res.StrAttr("NCHW")},
+                   {"output_format", pat.Attr("data_format")},
                }});
     quantize_op({&res.Tensor("quantize_" + std::to_string(index_))},
                 {&res.Tensor("quantize_out_" + std::to_string(index_))});
