@@ -39,7 +39,7 @@ namespace {
3939
4040class AutoLayoutPass : public pir ::Pass {
4141 public:
42- AutoLayoutPass () : pir::Pass(" auto_layout_pass" , 3 ) {}
42+ AutoLayoutPass () : pir::Pass(" auto_layout_pass" , 2 ) {}
4343 void Run (pir::Operation* op) override {
4444 for (size_t i = 0 ; i < op->num_regions (); ++i) {
4545 auto & region = op->region (i);
@@ -138,8 +138,10 @@ class AutoLayoutPass : public pir::Pass {
138138 if (op->HasTrait <pir::ImmutableLayoutTrait>()) continue ;
139139 if (op->operands ().size () == 0 ) continue ;
140140
141- // NHWC ops branch, Only support conv2d now, it will add white list later.
142- if (op->isa <paddle::dialect::Conv2dOp>()) {
141+ // NHWC ops branch, Only support conv2d and fused_conv2d_add_act now, it
142+ // will add white list later.
143+ if (op->isa <paddle::dialect::Conv2dOp>() ||
144+ op->isa <paddle::dialect::FusedConv2dAddActOp>()) {
143145 if (op->HasAttribute (" data_format" ) &&
144146 op->attribute <pir::StrAttribute>(" data_format" ).AsString () ==
145147 " NCHW" ) {
@@ -160,14 +162,8 @@ class AutoLayoutPass : public pir::Pass {
160162 // Skip the operand which is not dense tensor or not 4-D tensor, they don't
161163 // need transpose.
162164 bool JudgeValue (const pir::Value& value) {
163- if (!value) {
164- PADDLE_THROW (common::errors::Fatal (
165- " value is null, please check the input tensor." ));
166- }
167- if (!value.type ()) {
168- PADDLE_THROW (common::errors::Fatal (
169- " value type is null, please check the input tensor type." ));
170- }
165+ if (!value) return false ;
166+ if (!value.type ()) return false ;
171167 if (auto type = value.type ().dyn_cast <paddle::dialect::DenseTensorType>()) {
172168 return type.dims ().size () == 4 ;
173169 }
@@ -204,7 +200,6 @@ class AutoLayoutPass : public pir::Pass {
204200 pir::Builder& builder) { // NOLINT
205201 builder.SetInsertionPointAfter (op);
206202 for (auto & result : op->results ()) {
207- if (result.use_empty ()) continue ;
208203 if (!JudgeValue (result)) continue ;
209204 auto transpose_op =
210205 builder.Build <paddle::dialect::TransposeOp>(result, NHWC2NCHW_);
0 commit comments