|
19 | 19 | #include "paddle/pir/include/core/ir_printer.h" |
20 | 20 | #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" |
21 | 21 | #include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" |
22 | | - |
23 | 22 | namespace pir { |
24 | 23 |
|
25 | 24 | void YieldOp::Build(Builder &builder, |
@@ -92,13 +91,18 @@ TuplePopOp TuplePushOp::tuple_pop_op() { |
92 | 91 |
|
93 | 92 | void TuplePushOp::CacheGradOpSymbolicShape( |
94 | 93 | pir::InferSymbolicShapeContext *infer_context) { |
95 | | - const auto &x_shape = GetInputShape(infer_context, this->operation(), 0); |
96 | | - pir::InferSymbolicShapeCacheKey op_shape_info("cf.tuple_pop", {x_shape}, ); |
| 94 | + const auto &x_shape = |
| 95 | + paddle::dialect::GetInputShape(infer_context, this->operation(), 0); |
| 96 | + pir::InferSymbolicShapeCacheKey op_shape_info( |
| 97 | + "cf.tuple_pop", |
| 98 | + {x_shape}, |
| 99 | + pir::GetOrderedOriginalAttributes("cf.tuple_pop", |
| 100 | + this->operation()->attributes())); |
97 | 101 |
|
98 | 102 | std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list; |
99 | 103 | for (size_t index = 1; index < num_operands(); ++index) { |
100 | | - const auto &pop_value_shape = |
101 | | - GetGradVarShapeFromInput(infer_context, this->operation(), index); |
| 104 | + const auto &pop_value_shape = paddle::dialect::GetGradVarShapeFromInput( |
| 105 | + infer_context, this->operation(), index); |
102 | 106 | pop_value_shape_list.emplace_back(pop_value_shape); |
103 | 107 | } |
104 | 108 | infer_context->SetOpInferSymbolicShapeCache(op_shape_info, |
@@ -217,10 +221,14 @@ void StackCreateOp::VerifySig() { |
217 | 221 |
|
218 | 222 | bool StackCreateOp::InferSymbolicShape( |
219 | 223 | pir::InferSymbolicShapeContext *infer_context) { |
220 | | - symbol::DimExpr mark_symbol = infer_context->GetNextSymName(); |
221 | | - infer_context->SetShapeOrDataForValue(result(0), mark_symbol); |
222 | | - infer_context->SetShapeOrDataForValue(result(1), mark_symbol); |
223 | | - infer_context->SetShapeOrDataForValue(result(2), mark_symbol); |
| 224 | + std::vector<symbol::DimExpr> shape; |
| 225 | + shape.emplace_back(symbol::DimExpr(infer_context->GetNextSymName())); |
| 226 | + const symbol::ShapeOrDataDimExprs &mark_shape_or_data = |
| 227 | + symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(shape)); |
| 228 | + |
| 229 | + infer_context->SetShapeOrDataForValue(result(0), mark_shape_or_data); |
| 230 | + infer_context->SetShapeOrDataForValue(result(1), mark_shape_or_data); |
| 231 | + infer_context->SetShapeOrDataForValue(result(2), mark_shape_or_data); |
224 | 232 | return true; |
225 | 233 | } |
226 | 234 |
|
|
0 commit comments