Skip to content

Commit c39ac39

Browse files
authored
fix bug of infer symbol shape about tensor_list (#66662)
1 parent 6a75bf1 commit c39ac39

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -925,17 +925,9 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
925925
axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));
926926

927927
// sections
928-
const std::vector<symbol::DimExpr> &sections_sym = [&] {
929-
const auto &sections_shape_or_data =
930-
infer_context->GetShapeOrDataForValue(op->operand_source(1));
931-
std::vector<symbol::DimExpr> sections_sym;
932-
if (sections_shape_or_data.data().has_value()) {
933-
sections_sym = sections_shape_or_data.data().value();
934-
} else {
935-
sections_sym = sections_shape_or_data.shape();
936-
}
937-
return sections_sym;
938-
}();
928+
const std::vector<symbol::DimExpr> &sections_sym =
929+
details::GetExprVecFromData(
930+
infer_context->GetShapeOrDataForValue(op->operand_source(1)));
939931

940932
// output
941933
const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] {
@@ -1123,19 +1115,9 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
11231115
symbol::ShapeOrDataDimExprs repeat_times_shape_or_data =
11241116
infer_context->GetShapeOrDataForValue(operand_repeat_times);
11251117

1126-
std::vector<symbol::DimExpr> x_dimexpr;
1127-
if (x_shape_or_data.data().has_value()) {
1128-
x_dimexpr = x_shape_or_data.data().value();
1129-
} else {
1130-
x_dimexpr = x_shape_or_data.shape();
1131-
}
1132-
1133-
std::vector<symbol::DimExpr> repeat_times_dimexpr;
1134-
if (repeat_times_shape_or_data.data().has_value()) {
1135-
repeat_times_dimexpr = repeat_times_shape_or_data.data().value();
1136-
} else {
1137-
repeat_times_dimexpr = repeat_times_shape_or_data.shape();
1138-
}
1118+
std::vector<symbol::DimExpr> x_dimexpr = x_shape_or_data.shape();
1119+
std::vector<symbol::DimExpr> repeat_times_dimexpr =
1120+
details::GetExprVecFromData(repeat_times_shape_or_data);
11391121
if (repeat_times_dimexpr.empty()) {
11401122
repeat_times_dimexpr = std::vector<symbol::DimExpr>(x_dimexpr.size(), 1);
11411123
}

0 commit comments

Comments
 (0)