Skip to content

Commit 7831f62

Browse files
committed
unify pass
1 parent 0d7c438 commit 7831f62

File tree

2 files changed

+5
-19
lines changed

2 files changed

+5
-19
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2582,8 +2582,8 @@ void CheckInferSymbolicIfNeed(Program &program) { // NOLINT
25822582
void InferSymbolicShapePass(
25832583
std::shared_ptr<pir::PassManager> &pass_manager, // NOLINT
25842584
pir::Program &program) { // NOLINT
2585-
pir::IrContext *ctx = pir::IrContext::Instance();
2586-
ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
2585+
// pir::IrContext *ctx = pir::IrContext::Instance();
2586+
// ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
25872587
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
25882588
}
25892589

@@ -2601,20 +2601,6 @@ std::shared_ptr<Program> ApplyCommonSubexpressionEliminationPass(
26012601
return program;
26022602
}
26032603

2604-
void ApplyReduceAsToSumPass(
2605-
std::shared_ptr<pir::PassManager> &pass_manager, // NOLINT
2606-
pir::Program &program) { // NOLINT
2607-
#ifdef PADDLE_WITH_CINN
2608-
pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
2609-
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
2610-
#else
2611-
PADDLE_THROW(common::errors::Unimplemented(
2612-
"Currently we only support ReduceAsToSumPass Pass for Pir under "
2613-
"@to_static, please "
2614-
"compile PaddlePaddle with CINN"));
2615-
#endif
2616-
}
2617-
26182604
std::shared_ptr<Program> ApplyFusedBnAddActPass(
26192605
std::shared_ptr<Program> program) {
26202606
pir::PassManager pm(pir::IrContext::Instance(), 3);
@@ -2633,7 +2619,6 @@ void BindIrPass(pybind11::module *m) {
26332619
m->def("infer_symbolic_shape_pass", InferSymbolicShapePass);
26342620
m->def("apply_cse_pass", ApplyCommonSubexpressionEliminationPass);
26352621
m->def("apply_bn_add_act_pass", ApplyFusedBnAddActPass);
2636-
m->def("reduce_as_sum_pass", ApplyReduceAsToSumPass);
26372622

26382623
py::class_<Pass, std::shared_ptr<Pass>> pass(*m,
26392624
"Pass",

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,9 @@ def apply(self, program):
626626
# NOTE(gongshaotian): execute infer_symbolic_shape_pass before reduce_as_sum_pass
627627
pm = paddle.base.libpaddle.pir.PassManager()
628628
pm.add_pass("delete_assert_op_pass", {})
629-
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
630-
paddle.base.libpaddle.pir.reduce_as_sum_pass(pm, program)
629+
pm.add_pass("shape_optimization_pass", {})
630+
pm.add_pass("reduce_as_to_sum_pass", {})
631+
pm.add_pass("dead_code_elimination_pass", {})
631632
pm.run(program)
632633
return program
633634

0 commit comments

Comments
 (0)