Skip to content

Commit 293ba3f

Browse files
committed
unify pass
2 parents 40d4bed + c5da2ce commit 293ba3f

File tree

3 files changed

+5
-20
lines changed

3 files changed

+5
-20
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,9 +2585,7 @@ void InferSymbolicShapePass(
25852585
pir::Program &program) { // NOLINT
25862586
pir::IrContext *ctx = pir::IrContext::Instance();
25872587
ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
2588-
if (FLAGS_pir_apply_shape_optimization_pass) {
2589-
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
2590-
}
2588+
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
25912589
}
25922590

25932591
std::shared_ptr<Program> ApplyCommonSubexpressionEliminationPass(
@@ -2604,20 +2602,6 @@ std::shared_ptr<Program> ApplyCommonSubexpressionEliminationPass(
26042602
return program;
26052603
}
26062604

2607-
void ApplyReduceAsToSumPass(
2608-
std::shared_ptr<pir::PassManager> &pass_manager, // NOLINT
2609-
pir::Program &program) { // NOLINT
2610-
#ifdef PADDLE_WITH_CINN
2611-
pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
2612-
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
2613-
#else
2614-
PADDLE_THROW(common::errors::Unimplemented(
2615-
"Currently we only support ReduceAsToSumPass Pass for Pir under "
2616-
"@to_static, please "
2617-
"compile PaddlePaddle with CINN"));
2618-
#endif
2619-
}
2620-
26212605
std::shared_ptr<Program> ApplyFusedBnAddActPass(
26222606
std::shared_ptr<Program> program) {
26232607
pir::PassManager pm(pir::IrContext::Instance(), 3);
@@ -2636,7 +2620,6 @@ void BindIrPass(pybind11::module *m) {
26362620
m->def("infer_symbolic_shape_pass", InferSymbolicShapePass);
26372621
m->def("apply_cse_pass", ApplyCommonSubexpressionEliminationPass);
26382622
m->def("apply_bn_add_act_pass", ApplyFusedBnAddActPass);
2639-
m->def("reduce_as_sum_pass", ApplyReduceAsToSumPass);
26402623

26412624
py::class_<Pass, std::shared_ptr<Pass>> pass(*m,
26422625
"Pass",

paddle/pir/include/dialect/control_flow/ir/cf_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <functional>
1818

19+
#include "paddle/fluid/pir/dialect/operator/utils/shape_analysis_utils.h"
1920
#include "paddle/pir/include/core/builder.h"
2021
#include "paddle/pir/include/core/op_base.h"
2122
#include "paddle/pir/include/core/op_trait.h"

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,9 @@ def apply(self, program):
626626
# NOTE(gongshaotian): execute infer_symbolic_shape_pass before reduce_as_sum_pass
627627
pm = paddle.base.libpaddle.pir.PassManager()
628628
pm.add_pass("delete_assert_op_pass", {})
629-
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
630-
paddle.base.libpaddle.pir.reduce_as_sum_pass(pm, program)
629+
pm.add_pass("shape_optimization_pass", {})
630+
pm.add_pass("reduce_as_to_sum_pass", {})
631+
pm.add_pass("dead_code_elimination_pass", {})
631632
pm.run(program)
632633
return program
633634

0 commit comments

Comments
 (0)