Skip to content

Commit 2278927

Browse files
gongshaotian authored and Enigmatisms committed
[CINN] Support inferSymbolicShape for cf.tuple_pop and cf.tuple_push (PaddlePaddle#71153)
* add cache interface for cf.tuple_push
* tuple push grad v2
* support tuple_pop infer symbolic shape
* refine code
* refine code
* refine code
* refine ApplyReduceAsToSumPass
* revert some code
1 parent 9da7caf commit 2278927

File tree

5 files changed

+62
-22
lines changed

5 files changed

+62
-22
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2606,18 +2606,12 @@ std::shared_ptr<Program> ApplyCommonSubexpressionEliminationPass(
26062606
return program;
26072607
}
26082608

2609-
std::shared_ptr<Program> ApplyReduceAsToSumPass(
2610-
std::shared_ptr<Program> program) {
2609+
void ApplyReduceAsToSumPass(
2610+
std::shared_ptr<pir::PassManager> &pass_manager, // NOLINT
2611+
pir::Program &program) { // NOLINT
26112612
#ifdef PADDLE_WITH_CINN
2612-
pir::PassManager pm(pir::IrContext::Instance(), 2);
2613-
pm.AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
2614-
pm.AddPass(pir::CreateDeadCodeEliminationPass());
2615-
pm.Run(program.get());
2616-
if (FLAGS_print_ir) {
2617-
std::cout << "IR After ReduceAsToSumPass -------------------" << std::endl;
2618-
std::cout << *program << std::endl;
2619-
}
2620-
return program;
2613+
pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
2614+
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
26212615
#else
26222616
PADDLE_THROW(common::errors::Unimplemented(
26232617
"Currently we only support ReduceAsToSumPass Pass for Pir under "

paddle/pir/include/dialect/control_flow/ir/cf_op.h

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,10 @@
2020
#include "paddle/pir/include/core/op_base.h"
2121
#include "paddle/pir/include/core/op_trait.h"
2222
#include "paddle/pir/include/dialect/control_flow/ir/cf_interface.h"
23+
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
2324
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"
24-
25+
#include "paddle/pir/include/dialect/shape/utils/original_attributes_filter.h"
26+
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
2527
namespace pir {
2628
class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
2729
public:
@@ -39,7 +41,9 @@ class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
3941
///
4042
/// \brief Push a value tuple to a container.
4143
///
42-
class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
44+
class IR_API TuplePushOp : public Op<TuplePushOp,
45+
SideEffectTrait,
46+
CacheGradOpSymbolicShapeInterface> {
4347
public:
4448
using Op::Op;
4549
static const char *name() { return "cf.tuple_push"; }
@@ -70,6 +74,8 @@ class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
7074
return inlet().defining_op<ContainerOpInterface>();
7175
}
7276
TuplePopOp tuple_pop_op();
77+
78+
void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
7379
};
7480

7581
class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {

paddle/pir/src/dialect/control_flow/ir/cf_op.cc

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
#include "paddle/pir/include/core/ir_printer.h"
2020
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
2121
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
22-
2322
namespace pir {
2423

2524
void YieldOp::Build(Builder &builder,
@@ -90,6 +89,27 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
9089
return container_interface().tuple_pop_op();
9190
}
9291

92+
void TuplePushOp::CacheGradOpSymbolicShape(
93+
pir::InferSymbolicShapeContext *infer_context) {
94+
const auto &x_shape =
95+
infer_context->GetShapeOrDataForValue(this->operand_source(0));
96+
std::string tuple_pop_name(TuplePopOp::name());
97+
pir::InferSymbolicShapeCacheKey op_shape_info(
98+
tuple_pop_name,
99+
{x_shape},
100+
pir::GetOrderedOriginalAttributes("cf.tuple_pop",
101+
this->operation()->attributes()));
102+
103+
std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list;
104+
for (size_t index = 1; index < num_operands(); ++index) {
105+
const auto &pop_value_shape_or_data =
106+
infer_context->GetShapeOrDataForValue(this->operand_source(index));
107+
pop_value_shape_list.emplace_back(pop_value_shape_or_data);
108+
}
109+
infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
110+
pop_value_shape_list);
111+
}
112+
93113
void TuplePopOp::Build(Builder &builder, // NOLINT
94114
OperationArgument &argument, // NOLINT
95115
Value outlet) {
@@ -202,11 +222,14 @@ void StackCreateOp::VerifySig() {
202222

203223
bool StackCreateOp::InferSymbolicShape(
204224
pir::InferSymbolicShapeContext *infer_context) {
205-
const auto &null_shape_or_data =
206-
symbol::ShapeOrDataDimExprs(symbol::NullShapeOrDataDimExpr());
207-
infer_context->SetShapeOrDataForValue(result(0), null_shape_or_data);
208-
infer_context->SetShapeOrDataForValue(result(1), null_shape_or_data);
209-
infer_context->SetShapeOrDataForValue(result(2), null_shape_or_data);
225+
std::vector<symbol::DimExpr> shape;
226+
shape.emplace_back(symbol::DimExpr(infer_context->GetNextSymName()));
227+
const symbol::ShapeOrDataDimExprs &mark_shape_or_data =
228+
symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(shape));
229+
230+
infer_context->SetShapeOrDataForValue(result(0), mark_shape_or_data);
231+
infer_context->SetShapeOrDataForValue(result(1), mark_shape_or_data);
232+
infer_context->SetShapeOrDataForValue(result(2), mark_shape_or_data);
210233
return true;
211234
}
212235

paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -270,16 +270,28 @@ void InferSymExprForOp(Operation* op,
270270
op, infer_context->GetShapeOrDataForValue(op->result(0)));
271271
}
272272
} else {
273-
bool is_grad_op = [&]() {
273+
const bool is_grad_op = [&]() {
274274
std::string suffix = "_grad";
275275
const auto& op_name = op->name();
276276
if (op_name.size() < suffix.size()) return false;
277277
return op_name.compare(
278278
op_name.size() - suffix.size(), suffix.size(), suffix) == 0;
279279
}();
280+
281+
const bool is_special_cached_op = [&]() {
282+
const auto& op_name = op->name();
283+
std::vector<std::string> special_cached_ops = {
284+
"cf.tuple_pop",
285+
};
286+
return (std::find(special_cached_ops.begin(),
287+
special_cached_ops.end(),
288+
op_name) != special_cached_ops.end());
289+
}();
290+
280291
if (!is_grad_op)
281292
LOG(WARNING) << op->name()
282293
<< " DOES NOT have InferSymbolicShapeInterface!";
294+
283295
const bool all_outs_static_dims = [&] {
284296
bool all_static_dims = true;
285297
for (uint32_t i = 0; i < op->num_results(); ++i) {
@@ -293,7 +305,7 @@ void InferSymExprForOp(Operation* op,
293305
return all_static_dims;
294306
}();
295307

296-
if (all_outs_static_dims) {
308+
if (all_outs_static_dims && !is_special_cached_op) {
297309
for (uint32_t i = 0; i < op->num_results(); ++i) {
298310
infer_context->SetSymbolForValueByStaticShape(op->result(i));
299311
}

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -634,7 +634,12 @@ class FullGraphPreProcessPass(ValuePreservePass):
634634
def apply(self, program):
635635
program = paddle.base.libpaddle.pir.apply_bn_add_act_pass(program)
636636
if self.use_cinn_pass:
637-
program = paddle.base.libpaddle.pir.reduce_as_sum_pass(program)
637+
# NOTE(gongshaotian): execute infer_symbolic_shape_pass before reduce_as_sum_pass
638+
pm = paddle.base.libpaddle.pir.PassManager()
639+
pm.add_pass("delete_assert_op_pass", {})
640+
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
641+
paddle.base.libpaddle.pir.reduce_as_sum_pass(pm, program)
642+
pm.run(program)
638643
return program
639644

640645

0 commit comments

Comments (0)