@@ -64,6 +64,18 @@ class CastBf16Pattern : public pir::OpRewritePattern<OpType> {
6464
6565 pir::IrContext *ctx = rewriter.ir_context ();
6666
67+ auto dtype_attr = attributes[" dtype" ];
68+ phi::DataType dtype =
69+ dtype_attr.template dyn_cast <paddle::dialect::DataTypeAttribute>()
70+ .data ();
71+ if (dtype == phi::DataType::FLOAT32) {
72+ pir::Attribute new_dtype =
73+ paddle::dialect::DataTypeAttribute::get (ctx, phi::DataType::BFLOAT16);
74+ attributes[" dtype" ] = new_dtype;
75+ } else {
76+ return false ;
77+ }
78+
6779 std::unordered_map<std::string, pir::Attribute> q_attributes;
6880 q_attributes[" scale" ] = rewriter.float_attr (1 .0f );
6981 q_attributes[" shift" ] = rewriter.float_attr (0 .0f );
@@ -81,18 +93,6 @@ class CastBf16Pattern : public pir::OpRewritePattern<OpType> {
8193 type, pir::BFloat16Type::get (ctx), ctx);
8294 q_op->result (0 ).set_type (new_type);
8395
84- auto dtype_attr = attributes[" dtype" ];
85- phi::DataType dtype =
86- dtype_attr.template dyn_cast <paddle::dialect::DataTypeAttribute>()
87- .data ();
88- if (dtype == phi::DataType::FLOAT32) {
89- pir::Attribute new_dtype =
90- paddle::dialect::DataTypeAttribute::get (ctx, phi::DataType::BFLOAT16);
91- attributes[" dtype" ] = new_dtype;
92- } else {
93- return false ;
94- }
95-
9696 OpType new_cast = rewriter.Build <OpType>(q_op.output (), attributes);
9797
9898 std::unordered_map<std::string, pir::Attribute> dq_attributes;
0 commit comments