Skip to content

Commit c5da2ce

Browse files
committed
tuple push grad v2
1 parent e2df3a0 commit c5da2ce

File tree

2 files changed

+20
-10
lines changed
  • paddle/pir

2 files changed

+20
-10
lines changed

paddle/pir/include/dialect/control_flow/ir/cf_op.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
#include <functional>
1818

19+
#include "paddle/fluid/pir/dialect/operator/utils/shape_analysis_utils.h"
1920
#include "paddle/pir/include/core/builder.h"
2021
#include "paddle/pir/include/core/op_base.h"
2122
#include "paddle/pir/include/core/op_trait.h"
2223
#include "paddle/pir/include/dialect/control_flow/ir/cf_interface.h"
24+
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
2325
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"
2426

2527
namespace pir {
@@ -73,7 +75,7 @@ class IR_API TuplePushOp : public Op<TuplePushOp,
7375
}
7476
TuplePopOp tuple_pop_op();
7577

76-
CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
78+
void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
7779
};
7880

7981
class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {

paddle/pir/src/dialect/control_flow/ir/cf_op.cc

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "paddle/pir/include/core/ir_printer.h"
2020
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
2121
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
22-
2322
namespace pir {
2423

2524
void YieldOp::Build(Builder &builder,
@@ -92,13 +91,18 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
9291

9392
void TuplePushOp::CacheGradOpSymbolicShape(
9493
pir::InferSymbolicShapeContext *infer_context) {
95-
const auto &x_shape = GetInputShape(infer_context, this->operation(), 0);
96-
pir::InferSymbolicShapeCacheKey op_shape_info("cf.tuple_pop", {x_shape}, );
94+
const auto &x_shape =
95+
paddle::dialect::GetInputShape(infer_context, this->operation(), 0);
96+
pir::InferSymbolicShapeCacheKey op_shape_info(
97+
"cf.tuple_pop",
98+
{x_shape},
99+
pir::GetOrderedOriginalAttributes("cf.tuple_pop",
100+
this->operation()->attributes()));
97101

98102
std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list;
99103
for (size_t index = 1; index < num_operands(); ++index) {
100-
const auto &pop_value_shape =
101-
GetGradVarShapeFromInput(infer_context, this->operation(), index);
104+
const auto &pop_value_shape = paddle::dialect::GetGradVarShapeFromInput(
105+
infer_context, this->operation(), index);
102106
pop_value_shape_list.emplace_back(pop_value_shape);
103107
}
104108
infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
@@ -217,10 +221,14 @@ void StackCreateOp::VerifySig() {
217221

218222
bool StackCreateOp::InferSymbolicShape(
219223
pir::InferSymbolicShapeContext *infer_context) {
220-
symbol::DimExpr mark_symbol = infer_context->GetNextSymName();
221-
infer_context->SetShapeOrDataForValue(result(0), mark_symbol);
222-
infer_context->SetShapeOrDataForValue(result(1), mark_symbol);
223-
infer_context->SetShapeOrDataForValue(result(2), mark_symbol);
224+
std::vector<symbol::DimExpr> shape;
225+
shape.emplace_back(symbol::DimExpr(infer_context->GetNextSymName()));
226+
const symbol::ShapeOrDataDimExprs &mark_shape_or_data =
227+
symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(shape));
228+
229+
infer_context->SetShapeOrDataForValue(result(0), mark_shape_or_data);
230+
infer_context->SetShapeOrDataForValue(result(1), mark_shape_or_data);
231+
infer_context->SetShapeOrDataForValue(result(2), mark_shape_or_data);
224232
return true;
225233
}
226234

0 commit comments

Comments
 (0)