Skip to content

Commit da539f1

Browse files
committed
lift up check
1 parent ef77944 commit da539f1

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ class CastBf16Pattern : public pir::OpRewritePattern<OpType> {
6464

6565
pir::IrContext *ctx = rewriter.ir_context();
6666

67+
auto dtype_attr = attributes["dtype"];
68+
phi::DataType dtype =
69+
dtype_attr.template dyn_cast<paddle::dialect::DataTypeAttribute>()
70+
.data();
71+
if (dtype == phi::DataType::FLOAT32) {
72+
pir::Attribute new_dtype =
73+
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::BFLOAT16);
74+
attributes["dtype"] = new_dtype;
75+
} else {
76+
return false;
77+
}
78+
6779
std::unordered_map<std::string, pir::Attribute> q_attributes;
6880
q_attributes["scale"] = rewriter.float_attr(1.0f);
6981
q_attributes["shift"] = rewriter.float_attr(0.0f);
@@ -81,18 +93,6 @@ class CastBf16Pattern : public pir::OpRewritePattern<OpType> {
8193
type, pir::BFloat16Type::get(ctx), ctx);
8294
q_op->result(0).set_type(new_type);
8395

84-
auto dtype_attr = attributes["dtype"];
85-
phi::DataType dtype =
86-
dtype_attr.template dyn_cast<paddle::dialect::DataTypeAttribute>()
87-
.data();
88-
if (dtype == phi::DataType::FLOAT32) {
89-
pir::Attribute new_dtype =
90-
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::BFLOAT16);
91-
attributes["dtype"] = new_dtype;
92-
} else {
93-
return false;
94-
}
95-
9696
OpType new_cast = rewriter.Build<OpType>(q_op.output(), attributes);
9797

9898
std::unordered_map<std::string, pir::Attribute> dq_attributes;

0 commit comments

Comments
 (0)