Skip to content

Commit 8cdf2c1

Browse files
authored
[CINN]Add InferSymbolicShape for tensor_to_array (#69069)
* test * Refine logic * Refine logic * revert useless code * apply review * fix codestyle * revert select_output * fix codestyle
1 parent 77bd1e9 commit 8cdf2c1

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,17 @@ std::vector<pir::Type> TensorToArrayOp::InferMeta(
26132613
return argument_outputs;
26142614
}
26152615

2616+
bool TensorToArrayOp::InferSymbolicShape(
2617+
pir::InferSymbolicShapeContext *infer_context) {
2618+
const auto &x_shape_or_data =
2619+
infer_context->GetShapeOrDataForValue(x())
2620+
.dyn_cast<symbol::RankedTensorArrayShapeOrDataDimExprs>();
2621+
infer_context->SetShapeOrDataForValue(
2622+
x_grad(), symbol::ShapeOrDataDimExprs{x_shape_or_data});
2623+
2624+
return true;
2625+
}
2626+
26162627
OpInfoTuple SliceArrayOp::GetOpInfo() {
26172628
std::vector<paddle::dialect::OpInputInfo> inputs = {
26182629
paddle::dialect::OpInputInfo("input",

paddle/fluid/pir/dialect/operator/ir/manual_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,10 @@ class TEST_API ArrayToTensorOp : public pir::Op<ArrayToTensorOp,
407407
const std::vector<std::vector<bool>> &stop_gradients);
408408
};
409409

410-
class TEST_API TensorToArrayOp
411-
: public pir::Op<TensorToArrayOp, OpYamlInfoInterface, InferMetaInterface> {
410+
class TEST_API TensorToArrayOp : public pir::Op<TensorToArrayOp,
411+
OpYamlInfoInterface,
412+
InferMetaInterface,
413+
InferSymbolicShapeInterface> {
412414
public:
413415
using Op::Op;
414416
static const char *name() { return "pd_op.tensor_to_array"; }
@@ -429,6 +431,7 @@ class TEST_API TensorToArrayOp
429431
static std::vector<pir::Type> InferMeta(
430432
const std::vector<pir::Value> &input_values,
431433
pir::AttributeMap *p_attributes);
434+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
432435
};
433436

434437
class TEST_API SliceArrayOp

0 commit comments

Comments
 (0)