Commit eb17f1f

gongshaotian authored and lixcli committed
[CINN]Refine code and Fix bug in InferSymbolicShape of reshape op (PaddlePaddle#65958)
* [CINN]Refine code and Fix bug in InferSymbolicShape of reshape op
* fix bug where in_dims have symbol
1 parent e7149b3 commit eb17f1f

File tree

1 file changed: +83 -66 lines


paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 83 additions & 66 deletions
@@ -71,7 +71,7 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
   const auto &axis_shape_or_data =
       infer_context->GetShapeOrDataForValue(op->operand_source(1));
   int axis =
-      static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
+      static_cast<int>(axis_shape_or_data.data().value().at(0).Get<int64_t>());
   if (axis < 0) axis += rank;

   const auto &out_sym_shape = [&] {
@@ -84,14 +84,14 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
       }
     } else {
       for (int i = 0; i < axis; i++) {
-        out_sym_shape.emplace_back(input_sym_shape[i]);
+        out_sym_shape.emplace_back(input_sym_shape.at(i));
       }
       if (keepdims) {
         out_sym_shape.emplace_back(std::int64_t(1));
       }

       for (int i = axis + 1; i < rank; i++) {
-        out_sym_shape.emplace_back(input_sym_shape[i]);
+        out_sym_shape.emplace_back(input_sym_shape.at(i));
       }
     }
     return out_sym_shape;
@@ -216,7 +216,7 @@ bool DiagEmbedOpInferSymbolicShape(
   int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
   int64_t offset_ = static_cast<int64_t>(std::abs(offset));
   symbol::DimExpr new_dim_len =
-      symbol::DimExpr(offset_) + x_dims[x_dims.size() - 1];
+      symbol::DimExpr(offset_) + x_dims.at(x_dims.size() - 1);

   const auto &out_dims = [&] {
     std::vector<symbol::DimExpr> out_dims = x_dims;
@@ -245,8 +245,8 @@ bool DiagonalOpInferSymbolicShape(
   int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2;

   auto out_dims = x_dims;
-  auto axis1_size = out_dims[axis1_];
-  auto axis2_size = out_dims[axis2_];
+  auto axis1_size = out_dims.at(axis1_);
+  auto axis2_size = out_dims.at(axis2_);
   out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_));
   out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_));

@@ -308,7 +308,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
       std::vector<symbol::DimExpr> level_dim = {next_sym_name, 4};
       multi_rois_out_shape.emplace_back(
           symbol::TensorShapeOrDataDimExprs(level_dim));
-      last_dim = last_dim - level_dim[0];
+      last_dim = last_dim - level_dim.at(0);
     }
     multi_rois_out_shape.emplace_back(symbol::TensorShapeOrDataDimExprs(
         {infer_context->GetNextSymName(), 4}));
@@ -381,14 +381,14 @@ bool FlattenOpInferSymbolicShape(
   std::vector<symbol::DimExpr> out_shape;
   out_shape.reserve(in_dims_size - stop_axis + start_axis + 1);
   for (int i = 0; i < start_axis; ++i) {
-    out_shape.push_back(x_shape[i]);
+    out_shape.push_back(x_shape.at(i));
   }
   for (int i = start_axis; i <= stop_axis; i++) {
-    outer = outer * x_shape[i];
+    outer = outer * x_shape.at(i);
   }
   out_shape.push_back(outer);
   for (int i = stop_axis + 1; i < in_dims_size; i++) {
-    out_shape.push_back(x_shape[i]);
+    out_shape.push_back(x_shape.at(i));
   }

   symbol::ShapeOrDataDimExprs out_shape_data{
@@ -422,13 +422,13 @@ bool KthvalueOpInferSymbolicShape(
   if (axis < 0) axis += dim_size;
   std::vector<symbol::DimExpr> out_dims;
   for (int i = 0; i < axis; i++) {
-    out_dims.emplace_back(input_dims[i]);
+    out_dims.emplace_back(input_dims.at(i));
   }
   if (keepdim && dim_size > 0) {
     out_dims.emplace_back(symbol::DimExpr(1));
   }
   for (int i = axis + 1; i < dim_size; i++) {
-    out_dims.emplace_back(input_dims[i]);
+    out_dims.emplace_back(input_dims.at(i));
   }
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(out_dims)};
@@ -539,7 +539,8 @@ bool PadOpInferSymbolicShape(pir::Operation *op,
     std::vector<symbol::DimExpr> out_dims;
     out_dims.reserve(rank);
     for (size_t i = 0; i < rank; ++i) {
-      out_dims.push_back(x_dims_sym[i] + paddings[2 * i] + paddings[2 * i + 1]);
+      out_dims.push_back(x_dims_sym.at(i) + paddings.at(2 * i) +
+                         paddings.at(2 * i + 1));
     }
     return out_dims;
   }();
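The Pad formula above adds each axis's low and high padding to that axis's extent. A small numeric sketch with hypothetical shape and padding values, using plain integers instead of `symbol::DimExpr`:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical input shape and paddings stored as (low, high) pairs per axis.
  std::vector<int64_t> x_dims = {4, 5};
  std::vector<int64_t> paddings = {1, 1, 0, 2};  // axis 0: +1/+1, axis 1: +0/+2

  std::vector<int64_t> out_dims;
  out_dims.reserve(x_dims.size());
  for (size_t i = 0; i < x_dims.size(); ++i) {
    out_dims.push_back(x_dims.at(i) + paddings.at(2 * i) +
                       paddings.at(2 * i + 1));
  }
  for (int64_t d : out_dims) std::cout << d << " ";  // prints "6 7"
  std::cout << "\n";
  return 0;
}
```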
@@ -583,15 +584,15 @@ bool Pad3dOpInferSymbolicShape(pir::Operation *op,
                             "[6], but received [%d].",
                             paddings.size()));
     if (data_format == "NCDHW") {
-      out_dims[1] = x_shape[1];
-      out_dims[2] = x_shape[2] + paddings[4] + paddings[5];
-      out_dims[3] = x_shape[3] + paddings[2] + paddings[3];
-      out_dims[4] = x_shape[4] + paddings[0] + paddings[1];
+      out_dims.at(1) = x_shape.at(1);
+      out_dims.at(2) = x_shape.at(2) + paddings.at(4) + paddings.at(5);
+      out_dims.at(3) = x_shape.at(3) + paddings.at(2) + paddings.at(3);
+      out_dims.at(4) = x_shape.at(4) + paddings.at(0) + paddings.at(1);
     } else {
-      out_dims[1] = x_shape[1] + paddings[4] + paddings[5];
-      out_dims[2] = x_shape[2] + paddings[2] + paddings[3];
-      out_dims[3] = x_shape[3] + paddings[0] + paddings[1];
-      out_dims[4] = x_shape[4];
+      out_dims.at(1) = x_shape.at(1) + paddings.at(4) + paddings.at(5);
+      out_dims.at(2) = x_shape.at(2) + paddings.at(2) + paddings.at(3);
+      out_dims.at(3) = x_shape.at(3) + paddings.at(0) + paddings.at(1);
+      out_dims.at(4) = x_shape.at(4);
     }
     return out_dims;
   }();
@@ -652,9 +653,9 @@ bool RepeatInterleaveOpInferSymbolicShape(
     std::vector<symbol::DimExpr> out_sym_shape;
     for (int i = 0; i < x_rank; i++) {
       if (i == axis) {
-        out_sym_shape.push_back(in_dims_sym[i] * repeats);
+        out_sym_shape.push_back(in_dims_sym.at(i) * repeats);
       } else {
-        out_sym_shape.push_back(in_dims_sym[i]);
+        out_sym_shape.push_back(in_dims_sym.at(i));
       }
     }
     return out_sym_shape;
@@ -692,6 +693,13 @@ bool ReshapeOpInferSymbolicShape(
     return true;
   };

+  const auto &IsPositiveInteger = [&](const symbol::DimExpr &dim_expr) {
+    if (dim_expr.isa<int64_t>()) {
+      return dim_expr.dyn_cast<int64_t>() > static_cast<int64_t>(0);
+    }
+    return true;
+  };
+
   const auto &IsZero = [&](const symbol::DimExpr &dim_expr) {
     if (dim_expr.isa<int64_t>()) {
      return dim_expr.dyn_cast<int64_t>() == static_cast<int64_t>(0);
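The new `IsPositiveInteger` helper accepts any purely symbolic dim and rejects only concrete values that are not positive, so both `0` and `-1` are excluded from the product used later to resolve `-1`. A standalone sketch of that predicate logic, using `std::optional<int64_t>` as a stand-in for `symbol::DimExpr` (an assumption for illustration; the real type lives in Paddle's `symbol` namespace):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Stand-in for symbol::DimExpr: nullopt models a symbolic dimension,
// a value models a constant dimension.
using DimStandIn = std::optional<int64_t>;

// Mirrors IsPositiveInteger: symbols pass, constants must be > 0.
bool IsPositiveInteger(const DimStandIn &dim) {
  if (dim.has_value()) {
    return dim.value() > 0;
  }
  return true;
}

int main() {
  std::cout << IsPositiveInteger(DimStandIn{3}) << "\n";             // 1
  std::cout << IsPositiveInteger(DimStandIn{0}) << "\n";             // 0
  std::cout << IsPositiveInteger(DimStandIn{-1}) << "\n";            // 0
  std::cout << IsPositiveInteger(DimStandIn{std::nullopt}) << "\n";  // 1 (symbolic)
  return 0;
}
```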
@@ -702,26 +710,33 @@ bool ReshapeOpInferSymbolicShape(
   const std::vector<symbol::DimExpr> out_dims = [&] {
     const auto &original_shape =
         infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
+    ExprVec target_shape;
+    if (shape_dim_expr.data().has_value()) {
+      target_shape = shape_dim_expr.data().value();
+    }

+    // replace '0' with original shape
+    for (size_t i = 0; i < target_shape.size(); i++) {
+      if (i < original_shape.size() && IsZero(target_shape.at(i))) {
+        target_shape.at(i) = original_shape.at(i);
+      }
+    }
+
+    // replace '-1' with infered shape
     const auto &numel =
         GetProduct(original_shape, [](const auto &) { return true; });
-
-    ExprVec target_shape = details::GetExprVecFromData(shape_dim_expr);
     const auto &product_exclude_minus_one =
-        GetProduct(target_shape, IsNotMinusOne);
-
+        GetProduct(target_shape, IsPositiveInteger);
     const auto &input_dims = target_shape;

     std::vector<symbol::DimExpr> out_dims;
     out_dims.reserve(input_dims.size());
     for (size_t i = 0; i < input_dims.size(); ++i) {
-      auto out_dim_expr = IsNotMinusOne(input_dims[i])
-                              ? input_dims[i]
+      auto out_dim_expr = IsNotMinusOne(input_dims.at(i))
+                              ? input_dims.at(i)
                               : (numel / product_exclude_minus_one);
-      out_dim_expr = IsZero(input_dims[i]) ? original_shape[i] : out_dim_expr;
       out_dims.emplace_back(out_dim_expr);
     }
-
     return out_dims;
   }();

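The rewritten lambda above copies the input extent into every `0` entry of the target shape first, and only then resolves `-1` as `numel / product_of_known_dims`; in the old ordering a `0` entry could still poison `product_exclude_minus_one`. A concrete-integer sketch of the same two-pass rule (hypothetical helper and values for illustration; the real code operates on `symbol::DimExpr`, so the division stays symbolic there):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Resolve a reshape target shape: '0' copies the input dim at the same
// position, '-1' is inferred from the remaining element count.
std::vector<int64_t> ResolveShape(const std::vector<int64_t> &in_shape,
                                  std::vector<int64_t> target) {
  // Pass 1: replace '0' with the corresponding input extent.
  for (size_t i = 0; i < target.size(); ++i) {
    if (i < in_shape.size() && target[i] == 0) {
      target[i] = in_shape[i];
    }
  }
  // Pass 2: replace '-1' with numel / product of the known (positive) dims.
  int64_t numel = 1;
  for (int64_t d : in_shape) numel *= d;
  int64_t known = 1;
  for (int64_t d : target) {
    if (d > 0) known *= d;
  }
  for (int64_t &d : target) {
    if (d == -1) d = numel / known;
  }
  return target;
}

int main() {
  // Input of shape [2, 3, 4] reshaped with target [0, -1]: the 0 keeps
  // the leading 2, and -1 resolves to 24 / 2 = 12.
  for (int64_t d : ResolveShape({2, 3, 4}, {0, -1})) std::cout << d << " ";
  std::cout << "\n";  // prints "2 12"
  return 0;
}
```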
@@ -868,7 +883,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
   const bool &all_sections_sym_not_minus_one =
       All(sections_sym, IsNotMinusOne);
   if (all_sections_sym_not_minus_one) {
-    infer_context->AddEqualCstr(x_dims_sym[axis], sum_exclude_minus_one);
+    infer_context->AddEqualCstr(x_dims_sym.at(axis), sum_exclude_minus_one);
   }

   symbol::TensorListShapeOrDataDimExprs shape_data_list;
@@ -881,10 +896,11 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
       return shape_data_list;
     }
     for (uint32_t idx = 0; idx < sections_sym.size(); idx++) {
-      const auto &section_sym = sections_sym[idx];
-      output_dims_sym[axis] = IsNotMinusOne(section_sym)
-                                  ? section_sym
-                                  : x_dims_sym[axis] - sum_exclude_minus_one;
+      const auto &section_sym = sections_sym.at(idx);
+      output_dims_sym.at(axis) =
+          IsNotMinusOne(section_sym)
+              ? section_sym
+              : x_dims_sym.at(axis) - sum_exclude_minus_one;

       shape_data_list.push_back(
           symbol::TensorShapeOrDataDimExprs(output_dims_sym));
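In the Split hunk above, a `-1` section is resolved as the axis extent minus the sum of the explicitly given sections. A small concrete sketch of that rule, with hypothetical numbers and plain integers instead of DimExprs:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t axis_dim = 10;                   // extent of the split axis
  const std::vector<int64_t> sections = {2, -1, 3};

  // Sum the sections that are not -1.
  int64_t sum_exclude_minus_one = 0;
  for (int64_t s : sections) {
    if (s != -1) sum_exclude_minus_one += s;
  }

  // A -1 section takes whatever is left of the axis.
  for (int64_t s : sections) {
    int64_t resolved = (s != -1) ? s : axis_dim - sum_exclude_minus_one;
    std::cout << resolved << " ";  // prints "2 5 3"
  }
  std::cout << "\n";
  return 0;
}
```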
@@ -1052,7 +1068,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
   }

   for (size_t i = 0; i < repeat_times_dimexpr.size(); ++i) {
-    out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i];
+    out_shape.at(i) = x_dimexpr.at(i) * repeat_times_dimexpr.at(i);
   }

   symbol::ShapeOrDataDimExprs shape_data{
@@ -1084,7 +1100,7 @@ bool TopkOpInferSymbolicShape(pir::Operation *op,

   int x_rank = in_dims_sym.size();

-  int k = k_shape_or_data.data().value()[0].Get<int64_t>();
+  int k = k_shape_or_data.data().value().at(0).Get<int64_t>();

   if (axis < 0) axis += x_rank;
   const auto &out_sym_shape = [&] {
@@ -1093,7 +1109,7 @@ bool TopkOpInferSymbolicShape(pir::Operation *op,
       if (i == axis) {
         out_sym_shape.push_back(symbol::DimExpr(k));
       } else {
-        out_sym_shape.push_back(in_dims_sym[i]);
+        out_sym_shape.push_back(in_dims_sym.at(i));
       }
     }
     return out_sym_shape;
@@ -1161,7 +1177,7 @@ bool TransposeOpInferSymbolicShape(

   std::vector<symbol::DimExpr> out_dims(x_dims);
   for (int i = 0; i < axis_size; ++i) {
-    out_dims[i] = x_dims[formatted_axis[i]];
+    out_dims.at(i) = x_dims.at(formatted_axis.at(i));
   }

   infer_context->SetShapeOrDataForValue(op->result(0),
@@ -1225,26 +1241,27 @@ bool SqueezeOpInferSymbolicShape(
     for (size_t i = 0; i < in_dims_sym.size(); ++i) {
       // TODO(lanxianghit): if symbol here, maybe we need the result of dim expr
      // simplification
-      if (in_dims_sym[i] == 1) {
-        should_squeeze[i] = true;
+      if (in_dims_sym.at(i) == 1) {
+        should_squeeze.at(i) = true;
       }
     }
   } else {
     for (size_t i = 0; i < num_squeeze_dims; ++i) {
       if (in_dims_sym.size() == 0) {
         continue;
       }
-      int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims_sym.size()
-                                        : squeeze_dims[i];
+      int current = squeeze_dims.at(i) < 0
+                        ? squeeze_dims.at(i) + in_dims_sym.size()
+                        : squeeze_dims.at(i);

-      if (!should_squeeze[current]) {
+      if (!should_squeeze.at(current)) {
         // At compile time, dim of SYMBOL is allowed to squeeze?
-        if (in_dims_sym[current] == 1) {
-          should_squeeze[current] = true;
-        } else if (!in_dims_sym[current].Has<std::int64_t>()) {
-          should_squeeze[current] = true;
+        if (in_dims_sym.at(current) == 1) {
+          should_squeeze.at(current) = true;
+        } else if (!in_dims_sym.at(current).Has<std::int64_t>()) {
+          should_squeeze.at(current) = true;
         } else {
-          should_squeeze[current] = true;
+          should_squeeze.at(current) = true;
         }
       }
     }
@@ -1253,8 +1270,8 @@ bool SqueezeOpInferSymbolicShape(
   // Make output dimensions
   std::vector<symbol::DimExpr> output_shape_sym;
   for (size_t i = 0; i < in_dims_sym.size(); ++i) {
-    if (!should_squeeze[i]) {
-      output_shape_sym.emplace_back(in_dims_sym[i]);
+    if (!should_squeeze.at(i)) {
+      output_shape_sym.emplace_back(in_dims_sym.at(i));
     }
   }

@@ -1349,9 +1366,9 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op,
       return counts_dims;
     }
     std::vector<symbol::DimExpr> out_dims = x_dims_sym;
-    int axis = axes[0];
+    int axis = axes.at(0);
     axis = axis >= 0 ? axis : axis + rank;
-    out_dims[axis] = unique_dim_sym;
+    out_dims.at(axis) = unique_dim_sym;
     return out_dims;
   }();

@@ -1365,9 +1382,9 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op,
       }
       inverse_dims.push_back(product);
     } else {
-      int axis = axes[0];
+      int axis = axes.at(0);
       axis = axis >= 0 ? axis : axis + rank;
-      inverse_dims.push_back(x_dims_sym[axis]);
+      inverse_dims.push_back(x_dims_sym.at(axis));
     }
     return inverse_dims;
   }();
@@ -1421,9 +1438,9 @@ bool UniqueConsecutiveOpInferSymbolicShape(
       return counts_dims;
     }
     std::vector<symbol::DimExpr> out_dims = x_dims_sym;
-    int axis = axes[0];
+    int axis = axes.at(0);
     axis = axis >= 0 ? axis : axis + rank;
-    out_dims[axis] = unique_dim_sym;
+    out_dims.at(axis) = unique_dim_sym;
     return out_dims;
   }();

@@ -1437,9 +1454,9 @@ bool UniqueConsecutiveOpInferSymbolicShape(
       }
       inverse_dims.push_back(product);
     } else {
-      int axis = axes[0];
+      int axis = axes.at(0);
       axis = axis >= 0 ? axis : axis + rank;
-      inverse_dims.push_back(x_dims_sym[axis]);
+      inverse_dims.push_back(x_dims_sym.at(axis));
     }
     return inverse_dims;
   }();
@@ -1509,21 +1526,21 @@ bool UnsqueezeOpInferSymbolicShape(

     // Move old axis, and insert new axis
     for (int i = cur_output_rank; i >= cur; --i) {
-      if (result_sym_dims[i] == 1) {
+      if (result_sym_dims.at(i) == 1) {
         // Move axis
-        result_sym_dims[i + 1] = 1;
-        result_sym_dims[i] = 0;
+        result_sym_dims.at(i + 1) = 1;
+        result_sym_dims.at(i) = 0;
       }
     }
-    result_sym_dims[cur] = 1;
+    result_sym_dims.at(cur) = 1;
     // Add the output size.
     cur_output_rank++;
   }

   // Make output shape
   for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) {
-    if (result_sym_dims[out_idx] == 0) {
-      result_sym_dims[out_idx] = x_sym_shape[in_idx++];
+    if (result_sym_dims.at(out_idx) == 0) {
+      result_sym_dims.at(out_idx) = x_sym_shape.at(in_idx++);
     }
   }

0 commit comments
