File tree Expand file tree Collapse file tree 2 files changed +16
-2
lines changed
paddle/fluid/pir/dialect/operator/ir Expand file tree Collapse file tree 2 files changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -2613,6 +2613,17 @@ std::vector<pir::Type> TensorToArrayOp::InferMeta(
26132613 return argument_outputs;
26142614}
26152615
2616+ bool TensorToArrayOp::InferSymbolicShape (
2617+ pir::InferSymbolicShapeContext *infer_context) {
2618+ const auto &x_shape_or_data =
2619+ infer_context->GetShapeOrDataForValue (x ())
2620+ .dyn_cast <symbol::RankedTensorArrayShapeOrDataDimExprs>();
2621+ infer_context->SetShapeOrDataForValue (
2622+ x_grad (), symbol::ShapeOrDataDimExprs{x_shape_or_data});
2623+
2624+ return true ;
2625+ }
2626+
26162627OpInfoTuple SliceArrayOp::GetOpInfo () {
26172628 std::vector<paddle::dialect::OpInputInfo> inputs = {
26182629 paddle::dialect::OpInputInfo (" input" ,
Original file line number Diff line number Diff line change @@ -407,8 +407,10 @@ class TEST_API ArrayToTensorOp : public pir::Op<ArrayToTensorOp,
407407 const std::vector<std::vector<bool >> &stop_gradients);
408408};
409409
410- class TEST_API TensorToArrayOp
411- : public pir::Op<TensorToArrayOp, OpYamlInfoInterface, InferMetaInterface> {
410+ class TEST_API TensorToArrayOp : public pir::Op<TensorToArrayOp,
411+ OpYamlInfoInterface,
412+ InferMetaInterface,
413+ InferSymbolicShapeInterface> {
412414 public:
413415 using Op::Op;
414416 static const char *name () { return " pd_op.tensor_to_array" ; }
@@ -429,6 +431,7 @@ class TEST_API TensorToArrayOp
429431 static std::vector<pir::Type> InferMeta (
430432 const std::vector<pir::Value> &input_values,
431433 pir::AttributeMap *p_attributes);
434+ bool InferSymbolicShape (pir::InferSymbolicShapeContext *infer_context);
432435};
433436
434437class TEST_API SliceArrayOp
You can’t perform that action at this time.
0 commit comments