Skip to content

Commit 2ca729b

Browse files
add reshape build
1 parent bd92abc commit 2ca729b

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

paddle/fluid/pir/drr/src/ir_operation_factory.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,43 @@ void OperationFactory::RegisterManualOpCreator() {
673673
attrs);
674674
});
675675

676+
RegisterOperationCreator(
677+
"onednn_op.scale",
678+
[](const std::vector<pir::Value>& inputs,
679+
const pir::AttributeMap& attrs,
680+
pir::PatternRewriter& rewriter) {
681+
if (inputs.size() == 2) {
682+
// Add after scale add this attr
683+
// PADDLE_ENFORCE_EQ(attrs.find("mkldnn_data_type") != attrs.end(),
684+
// true,
685+
// phi::errors::InvalidArgument(
686+
// "'mkldnn_data_type' Attribute is expected "
687+
// "for ScaleOp. "));
688+
// std::string mkldnn_data_type = attrs.at("mkldnn_data_type")
689+
// .dyn_cast<pir::StrAttribute>()
690+
// .AsString();
691+
PADDLE_ENFORCE_EQ(attrs.find("bias_after_scale") != attrs.end(),
692+
true,
693+
phi::errors::InvalidArgument(
694+
"'bias_after_scale' Attribute is expected "
695+
"for ScaleOp. "));
696+
bool bias_after_scale = attrs.at("bias_after_scale")
697+
.dyn_cast<pir::BoolAttribute>()
698+
.data();
699+
700+
PADDLE_ENFORCE_EQ(
701+
attrs.find("bias") != attrs.end(),
702+
true,
703+
phi::errors::InvalidArgument("'bias' Attribute is expected "
704+
"for ScaleOp. "));
705+
bool bias = attrs.at("bias").dyn_cast<pir::FloatAttribute>().data();
706+
707+
return rewriter.Build<paddle::onednn::dialect::ScaleOp>(
708+
inputs[0], inputs[1], bias, bias_after_scale);
709+
}
710+
return rewriter.Build<paddle::onednn::dialect::ScaleOp>(inputs[0],
711+
attrs);
712+
});
676713
#endif
677714

678715
RegisterOperationCreator(

0 commit comments

Comments
 (0)