16 changes: 5 additions & 11 deletions paddle/fluid/pybind/pir.cc
@@ -2606,18 +2606,12 @@ std::shared_ptr<Program> ApplyCommonSubexpressionEliminationPass(
   return program;
 }
 
-std::shared_ptr<Program> ApplyReduceAsToSumPass(
-    std::shared_ptr<Program> program) {
+void ApplyReduceAsToSumPass(
+    std::shared_ptr<pir::PassManager> &pass_manager,  // NOLINT
+    pir::Program &program) {  // NOLINT
 #ifdef PADDLE_WITH_CINN
-  pir::PassManager pm(pir::IrContext::Instance(), 2);
-  pm.AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
-  pm.AddPass(pir::CreateDeadCodeEliminationPass());
-  pm.Run(program.get());
-  if (FLAGS_print_ir) {
-    std::cout << "IR After ReduceAsToSumPass -------------------" << std::endl;
-    std::cout << *program << std::endl;
-  }
-  return program;
+  pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
+  pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
 #else
   PADDLE_THROW(common::errors::Unimplemented(
       "Currently we only support ReduceAsToSumPass Pass for Pir under "
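The helper now only registers its passes into a caller-owned PassManager (the local PassManager, its Run call, and the FLAGS_print_ir dump are removed), so several passes can be composed and executed as one pipeline. A minimal caller-side sketch, reusing only the constructor and Run call visible in the removed code; the composition shown and the opt_level label are assumptions, not part of this diff:

    // Caller builds the pipeline once, registers passes, then runs them together.
    std::shared_ptr<pir::PassManager> pass_manager =
        std::make_shared<pir::PassManager>(pir::IrContext::Instance(), /*opt_level=*/2);
    ApplyReduceAsToSumPass(pass_manager, *program);  // registers passes, does not run them
    pass_manager->Run(program.get());                // program: std::shared_ptr<Program>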
10 changes: 8 additions & 2 deletions paddle/pir/include/dialect/control_flow/ir/cf_op.h
@@ -20,8 +20,10 @@
 #include "paddle/pir/include/core/op_base.h"
 #include "paddle/pir/include/core/op_trait.h"
 #include "paddle/pir/include/dialect/control_flow/ir/cf_interface.h"
+#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
 #include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"
-
+#include "paddle/pir/include/dialect/shape/utils/original_attributes_filter.h"
+#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
 namespace pir {
 class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
  public:
@@ -39,7 +41,9 @@ class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
 ///
 /// \brief Push a value tuple to a container.
 ///
-class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
+class IR_API TuplePushOp : public Op<TuplePushOp,
+                                     SideEffectTrait,
+                                     CacheGradOpSymbolicShapeInterface> {
  public:
   using Op::Op;
   static const char *name() { return "cf.tuple_push"; }
@@ -70,6 +74,8 @@ class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
     return inlet().defining_op<ContainerOpInterface>();
   }
   TuplePopOp tuple_pop_op();
+
+  void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 };
 
 class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {
35 changes: 29 additions & 6 deletions paddle/pir/src/dialect/control_flow/ir/cf_op.cc
@@ -19,7 +19,6 @@
 #include "paddle/pir/include/core/ir_printer.h"
 #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
 #include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
-
 namespace pir {
 
 void YieldOp::Build(Builder &builder,
@@ -90,6 +89,27 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
   return container_interface().tuple_pop_op();
 }
 
+void TuplePushOp::CacheGradOpSymbolicShape(
+    pir::InferSymbolicShapeContext *infer_context) {
+  const auto &x_shape =
+      infer_context->GetShapeOrDataForValue(this->operand_source(0));
+  std::string tuple_pop_name(TuplePopOp::name());
+  pir::InferSymbolicShapeCacheKey op_shape_info(
+      tuple_pop_name,
+      {x_shape},
+      pir::GetOrderedOriginalAttributes("cf.tuple_pop",
+                                        this->operation()->attributes()));
+
+  std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list;
+  for (size_t index = 1; index < num_operands(); ++index) {
+    const auto &pop_value_shape_or_data =
+        infer_context->GetShapeOrDataForValue(this->operand_source(index));
+    pop_value_shape_list.emplace_back(pop_value_shape_or_data);
+  }
+  infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
+                                              pop_value_shape_list);
+}
+
 void TuplePopOp::Build(Builder &builder,             // NOLINT
                        OperationArgument &argument,  // NOLINT
                        Value outlet) {
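The push op writes its pop-side shape information into the infer context's op-level cache: the key combines TuplePopOp::name(), the inlet value's shape, and the push op's filtered original attributes; the cached value is the list of shapes of everything pushed (operands 1..n-1). A hypothetical pop-side sketch of rebuilding the same key for a lookup; `pop_op` and the matching inlet/outlet shapes are assumptions, since the diff only shows the write side:

    // The pop side can form an identical key because StackCreateOp (below)
    // now assigns inlet and outlet the same symbolic mark shape.
    pir::InferSymbolicShapeCacheKey lookup_key(
        std::string(TuplePopOp::name()),
        {infer_context->GetShapeOrDataForValue(pop_op.operand_source(0))},
        pir::GetOrderedOriginalAttributes("cf.tuple_pop",
                                          pop_op.operation()->attributes()));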
@@ -202,11 +222,14 @@ void StackCreateOp::VerifySig() {
 
 bool StackCreateOp::InferSymbolicShape(
     pir::InferSymbolicShapeContext *infer_context) {
-  const auto &null_shape_or_data =
-      symbol::ShapeOrDataDimExprs(symbol::NullShapeOrDataDimExpr());
-  infer_context->SetShapeOrDataForValue(result(0), null_shape_or_data);
-  infer_context->SetShapeOrDataForValue(result(1), null_shape_or_data);
-  infer_context->SetShapeOrDataForValue(result(2), null_shape_or_data);
+  std::vector<symbol::DimExpr> shape;
+  shape.emplace_back(symbol::DimExpr(infer_context->GetNextSymName()));
+  const symbol::ShapeOrDataDimExprs &mark_shape_or_data =
+      symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(shape));
+
+  infer_context->SetShapeOrDataForValue(result(0), mark_shape_or_data);
+  infer_context->SetShapeOrDataForValue(result(1), mark_shape_or_data);
+  infer_context->SetShapeOrDataForValue(result(2), mark_shape_or_data);
   return true;
 }
 
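Design note: replacing the NullShapeOrDataDimExpr marker with a fresh one-element symbolic shape (minted via GetNextSymName) gives all three stack results a concrete, identical ShapeOrData entry. That appears to be what lets the push and pop sides of the tuple cache above build matching keys, and it spares downstream infer functions from handling null shape entries.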
16 changes: 14 additions & 2 deletions paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc
@@ -270,16 +270,28 @@ void InferSymExprForOp(Operation* op,
           op, infer_context->GetShapeOrDataForValue(op->result(0)));
     }
   } else {
-    bool is_grad_op = [&]() {
+    const bool is_grad_op = [&]() {
       std::string suffix = "_grad";
       const auto& op_name = op->name();
       if (op_name.size() < suffix.size()) return false;
       return op_name.compare(
           op_name.size() - suffix.size(), suffix.size(), suffix) == 0;
     }();
+
+    const bool is_special_cached_op = [&]() {
+      const auto& op_name = op->name();
+      std::vector<std::string> special_cached_ops = {
+          "cf.tuple_pop",
+      };
+      return (std::find(special_cached_ops.begin(),
+                        special_cached_ops.end(),
+                        op_name) != special_cached_ops.end());
+    }();
+
     if (!is_grad_op)
       LOG(WARNING) << op->name()
                    << " DOES NOT have InferSymbolicShapeInterface!";
+
     const bool all_outs_static_dims = [&] {
       bool all_static_dims = true;
       for (uint32_t i = 0; i < op->num_results(); ++i) {

Review thread on the "cf.tuple_pop" entry:
  Contributor: Suggestion: use pir::TuplePopOp::name() here instead of the string literal.
  Contributor (author): OK, I will change this in the next PR.
@@ -293,7 +305,7 @@ void InferSymExprForOp(Operation* op,
       return all_static_dims;
     }();
 
-    if (all_outs_static_dims) {
+    if (all_outs_static_dims && !is_special_cached_op) {
       for (uint32_t i = 0; i < op->num_results(); ++i) {
         infer_context->SetSymbolForValueByStaticShape(op->result(i));
       }
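A sketch of the change the reviewer asked for, deferred to the next PR in the thread above; it assumes cf_op.h is included here so the op class is visible:

    // Derive the entry from the op class instead of a string literal,
    // so the list cannot drift from the op's registered name.
    std::vector<std::string> special_cached_ops = {
        std::string(pir::TuplePopOp::name()),
    };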
7 changes: 6 additions & 1 deletion python/paddle/jit/dy2static/pir_partial_program.py
@@ -634,7 +634,12 @@ class FullGraphPreProcessPass(ValuePreservePass):
     def apply(self, program):
         program = paddle.base.libpaddle.pir.apply_bn_add_act_pass(program)
         if self.use_cinn_pass:
-            program = paddle.base.libpaddle.pir.reduce_as_sum_pass(program)
+            # NOTE(gongshaotian): execute infer_symbolic_shape_pass before reduce_as_sum_pass
+            pm = paddle.base.libpaddle.pir.PassManager()
+            pm.add_pass("delete_assert_op_pass", {})
+            paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
+            paddle.base.libpaddle.pir.reduce_as_sum_pass(pm, program)
+            pm.run(program)
         return program
 
 
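The Python side mirrors the new C++ binding: a single PassManager is created, delete_assert_op_pass and symbolic-shape inference are registered first, reduce_as_sum_pass is registered after them, and pm.run(program) executes the whole pipeline once. Per the NOTE, infer_symbolic_shape_pass must precede reduce_as_sum_pass, presumably so that symbolic shape information (including the tuple_push cache added above) is already populated when the reduce_as rewrite runs.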