Skip to content

Commit 2275419

Browse files
authored
Fix a bug and change the pass placement in d2s (dy2static). (#68759)
1 parent 3017878 commit 2275419

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

paddle/fluid/pir/transforms/general/auto_layout_pass.cc

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace {
3939

4040
class AutoLayoutPass : public pir::Pass {
4141
public:
42-
AutoLayoutPass() : pir::Pass("auto_layout_pass", 3) {}
42+
AutoLayoutPass() : pir::Pass("auto_layout_pass", 2) {}
4343
void Run(pir::Operation* op) override {
4444
for (size_t i = 0; i < op->num_regions(); ++i) {
4545
auto& region = op->region(i);
@@ -138,8 +138,10 @@ class AutoLayoutPass : public pir::Pass {
138138
if (op->HasTrait<pir::ImmutableLayoutTrait>()) continue;
139139
if (op->operands().size() == 0) continue;
140140

141-
// NHWC ops branch, Only support conv2d now, it will add white list later.
142-
if (op->isa<paddle::dialect::Conv2dOp>()) {
141+
// NHWC ops branch, Only support conv2d and fused_conv2d_add_act now, it
142+
// will add white list later.
143+
if (op->isa<paddle::dialect::Conv2dOp>() ||
144+
op->isa<paddle::dialect::FusedConv2dAddActOp>()) {
143145
if (op->HasAttribute("data_format") &&
144146
op->attribute<pir::StrAttribute>("data_format").AsString() ==
145147
"NCHW") {
@@ -160,14 +162,8 @@ class AutoLayoutPass : public pir::Pass {
160162
// Skip the operand which is not dense tensor or not 4-D tensor, they don't
161163
// need transpose.
162164
bool JudgeValue(const pir::Value& value) {
163-
if (!value) {
164-
PADDLE_THROW(common::errors::Fatal(
165-
"value is null, please check the input tensor."));
166-
}
167-
if (!value.type()) {
168-
PADDLE_THROW(common::errors::Fatal(
169-
"value type is null, please check the input tensor type."));
170-
}
165+
if (!value) return false;
166+
if (!value.type()) return false;
171167
if (auto type = value.type().dyn_cast<paddle::dialect::DenseTensorType>()) {
172168
return type.dims().size() == 4;
173169
}
@@ -204,7 +200,6 @@ class AutoLayoutPass : public pir::Pass {
204200
pir::Builder& builder) { // NOLINT
205201
builder.SetInsertionPointAfter(op);
206202
for (auto& result : op->results()) {
207-
if (result.use_empty()) continue;
208203
if (!JudgeValue(result)) continue;
209204
auto transpose_op =
210205
builder.Build<paddle::dialect::TransposeOp>(result, NHWC2NCHW_);

paddle/fluid/pir/transforms/general/auto_layout_simplify_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class RedundantTransposePattern
8585
class AutoLayoutSimplifyPass : public pir::PatternRewritePass {
8686
public:
8787
AutoLayoutSimplifyPass()
88-
: pir::PatternRewritePass("auto_layout_simplify_pass", 3) {}
88+
: pir::PatternRewritePass("auto_layout_simplify_pass", 2) {}
8989
pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
9090
pir::RewritePatternSet ps(context);
9191
ps.Add<RedundantTransposePattern>(context);

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,11 +674,6 @@ def _get_scope(self, program_id=None, use_scope_cache=False):
674674
# whole
675675
@switch_to_static_graph
676676
def _create_program(self, is_infer_mode=False):
677-
if auto_layout_is_enabled():
678-
pm = paddle.pir.PassManager(3)
679-
pm.add_pass("auto_layout_pass", {})
680-
pm.add_pass("auto_layout_simplify_pass", {})
681-
pm.run(self._origin_main_program)
682677

683678
if is_infer_mode:
684679

@@ -704,6 +699,11 @@ def pass_fn(forward_program, backward_program, program_name_attr):
704699

705700
# TODO(xiongkun): how to transfer the pruning program?
706701
infer_program = self.origin_runnable_program.clone()
702+
if auto_layout_is_enabled():
703+
pm = paddle.pir.PassManager(2)
704+
pm.add_pass("auto_layout_pass", {})
705+
pm.add_pass("auto_layout_simplify_pass", {})
706+
pm.run(infer_program.program)
707707
for hooker in self._hookers:
708708
hooker.after_infer(infer_program)
709709
infer_program.apply_pir_program_pass(pass_fn)
@@ -712,6 +712,12 @@ def pass_fn(forward_program, backward_program, program_name_attr):
712712
train_program: RunnableProgram = (
713713
self.origin_runnable_program.clone()
714714
)
715+
# Author(liujinnan): auto_layout_pass should be applied to the original_program before appending backward, so we put it here.
716+
if auto_layout_is_enabled():
717+
pm = paddle.pir.PassManager(2)
718+
pm.add_pass("auto_layout_pass", {})
719+
pm.add_pass("auto_layout_simplify_pass", {})
720+
pm.run(train_program.program)
715721
train_program = self._append_backward_desc(train_program)
716722
# Note: Only set grad type once after initializing train program. So we put it here.
717723
self._set_grad_type(self._params, train_program)

0 commit comments

Comments
 (0)