@@ -673,6 +673,43 @@ void OperationFactory::RegisterManualOpCreator() {
673673 attrs);
674674 });
675675
676+ RegisterOperationCreator (
677+ " onednn_op.scale" ,
678+ [](const std::vector<pir::Value>& inputs,
679+ const pir::AttributeMap& attrs,
680+ pir::PatternRewriter& rewriter) {
681+ if (inputs.size () == 2 ) {
682+ // Add after scale add this attr
683+ // PADDLE_ENFORCE_EQ(attrs.find("mkldnn_data_type") != attrs.end(),
684+ // true,
685+ // phi::errors::InvalidArgument(
686+ // "'mkldnn_data_type' Attribute is expected "
687+ // "for ScaleOp. "));
688+ // std::string mkldnn_data_type = attrs.at("mkldnn_data_type")
689+ // .dyn_cast<pir::StrAttribute>()
690+ // .AsString();
691+ PADDLE_ENFORCE_EQ (attrs.find (" bias_after_scale" ) != attrs.end (),
692+ true ,
693+ phi::errors::InvalidArgument (
694+ " 'bias_after_scale' Attribute is expected "
695+ " for ScaleOp. " ));
696+ bool bias_after_scale = attrs.at (" bias_after_scale" )
697+ .dyn_cast <pir::BoolAttribute>()
698+ .data ();
699+
700+ PADDLE_ENFORCE_EQ (
701+ attrs.find (" bias" ) != attrs.end (),
702+ true ,
703+ phi::errors::InvalidArgument (" 'bias' Attribute is expected "
704+ " for ScaleOp. " ));
705+ bool bias = attrs.at (" bias" ).dyn_cast <pir::FloatAttribute>().data ();
706+
707+ return rewriter.Build <paddle::onednn::dialect::ScaleOp>(
708+ inputs[0 ], inputs[1 ], bias, bias_after_scale);
709+ }
710+ return rewriter.Build <paddle::onednn::dialect::ScaleOp>(inputs[0 ],
711+ attrs);
712+ });
676713#endif
677714
678715 RegisterOperationCreator (
0 commit comments