@@ -71,7 +71,7 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
7171 const auto &axis_shape_or_data =
7272 infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
7373 int axis =
74- static_cast <int >(axis_shape_or_data.data ().value ()[ 0 ] .Get <int64_t >());
74+ static_cast <int >(axis_shape_or_data.data ().value (). at ( 0 ) .Get <int64_t >());
7575 if (axis < 0 ) axis += rank;
7676
7777 const auto &out_sym_shape = [&] {
@@ -84,14 +84,14 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
8484 }
8585 } else {
8686 for (int i = 0 ; i < axis; i++) {
87- out_sym_shape.emplace_back (input_sym_shape[i] );
87+ out_sym_shape.emplace_back (input_sym_shape. at (i) );
8888 }
8989 if (keepdims) {
9090 out_sym_shape.emplace_back (std::int64_t (1 ));
9191 }
9292
9393 for (int i = axis + 1 ; i < rank; i++) {
94- out_sym_shape.emplace_back (input_sym_shape[i] );
94+ out_sym_shape.emplace_back (input_sym_shape. at (i) );
9595 }
9696 }
9797 return out_sym_shape;
@@ -216,7 +216,7 @@ bool DiagEmbedOpInferSymbolicShape(
216216 int dim2_ = dim2 < 0 ? x_dims.size () + dim2 + 1 : dim2;
217217 int64_t offset_ = static_cast <int64_t >(std::abs (offset));
218218 symbol::DimExpr new_dim_len =
219- symbol::DimExpr (offset_) + x_dims[ x_dims.size () - 1 ] ;
219+ symbol::DimExpr (offset_) + x_dims. at ( x_dims.size () - 1 ) ;
220220
221221 const auto &out_dims = [&] {
222222 std::vector<symbol::DimExpr> out_dims = x_dims;
@@ -245,8 +245,8 @@ bool DiagonalOpInferSymbolicShape(
245245 int axis2_ = axis2 < 0 ? x_dims.size () + axis2 : axis2;
246246
247247 auto out_dims = x_dims;
248- auto axis1_size = out_dims[ axis1_] ;
249- auto axis2_size = out_dims[ axis2_] ;
248+ auto axis1_size = out_dims. at ( axis1_) ;
249+ auto axis2_size = out_dims. at ( axis2_) ;
250250 out_dims.erase (out_dims.begin () + std::max (axis1_, axis2_));
251251 out_dims.erase (out_dims.begin () + std::min (axis1_, axis2_));
252252
@@ -308,7 +308,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape(
308308 std::vector<symbol::DimExpr> level_dim = {next_sym_name, 4 };
309309 multi_rois_out_shape.emplace_back (
310310 symbol::TensorShapeOrDataDimExprs (level_dim));
311- last_dim = last_dim - level_dim[ 0 ] ;
311+ last_dim = last_dim - level_dim. at ( 0 ) ;
312312 }
313313 multi_rois_out_shape.emplace_back (symbol::TensorShapeOrDataDimExprs (
314314 {infer_context->GetNextSymName (), 4 }));
@@ -381,14 +381,14 @@ bool FlattenOpInferSymbolicShape(
381381 std::vector<symbol::DimExpr> out_shape;
382382 out_shape.reserve (in_dims_size - stop_axis + start_axis + 1 );
383383 for (int i = 0 ; i < start_axis; ++i) {
384- out_shape.push_back (x_shape[i] );
384+ out_shape.push_back (x_shape. at (i) );
385385 }
386386 for (int i = start_axis; i <= stop_axis; i++) {
387- outer = outer * x_shape[i] ;
387+ outer = outer * x_shape. at (i) ;
388388 }
389389 out_shape.push_back (outer);
390390 for (int i = stop_axis + 1 ; i < in_dims_size; i++) {
391- out_shape.push_back (x_shape[i] );
391+ out_shape.push_back (x_shape. at (i) );
392392 }
393393
394394 symbol::ShapeOrDataDimExprs out_shape_data{
@@ -422,13 +422,13 @@ bool KthvalueOpInferSymbolicShape(
422422 if (axis < 0 ) axis += dim_size;
423423 std::vector<symbol::DimExpr> out_dims;
424424 for (int i = 0 ; i < axis; i++) {
425- out_dims.emplace_back (input_dims[i] );
425+ out_dims.emplace_back (input_dims. at (i) );
426426 }
427427 if (keepdim && dim_size > 0 ) {
428428 out_dims.emplace_back (symbol::DimExpr (1 ));
429429 }
430430 for (int i = axis + 1 ; i < dim_size; i++) {
431- out_dims.emplace_back (input_dims[i] );
431+ out_dims.emplace_back (input_dims. at (i) );
432432 }
433433 symbol::ShapeOrDataDimExprs shape_data{
434434 symbol::TensorShapeOrDataDimExprs (out_dims)};
@@ -539,7 +539,8 @@ bool PadOpInferSymbolicShape(pir::Operation *op,
539539 std::vector<symbol::DimExpr> out_dims;
540540 out_dims.reserve (rank);
541541 for (size_t i = 0 ; i < rank; ++i) {
542- out_dims.push_back (x_dims_sym[i] + paddings[2 * i] + paddings[2 * i + 1 ]);
542+ out_dims.push_back (x_dims_sym.at (i) + paddings.at (2 * i) +
543+ paddings.at (2 * i + 1 ));
543544 }
544545 return out_dims;
545546 }();
@@ -583,15 +584,15 @@ bool Pad3dOpInferSymbolicShape(pir::Operation *op,
583584 " [6], but received [%d]." ,
584585 paddings.size ()));
585586 if (data_format == " NCDHW" ) {
586- out_dims[ 1 ] = x_shape[ 1 ] ;
587- out_dims[ 2 ] = x_shape[ 2 ] + paddings[ 4 ] + paddings[ 5 ] ;
588- out_dims[ 3 ] = x_shape[ 3 ] + paddings[ 2 ] + paddings[ 3 ] ;
589- out_dims[ 4 ] = x_shape[ 4 ] + paddings[ 0 ] + paddings[ 1 ] ;
587+ out_dims. at ( 1 ) = x_shape. at ( 1 ) ;
588+ out_dims. at ( 2 ) = x_shape. at ( 2 ) + paddings. at ( 4 ) + paddings. at ( 5 ) ;
589+ out_dims. at ( 3 ) = x_shape. at ( 3 ) + paddings. at ( 2 ) + paddings. at ( 3 ) ;
590+ out_dims. at ( 4 ) = x_shape. at ( 4 ) + paddings. at ( 0 ) + paddings. at ( 1 ) ;
590591 } else {
591- out_dims[ 1 ] = x_shape[ 1 ] + paddings[ 4 ] + paddings[ 5 ] ;
592- out_dims[ 2 ] = x_shape[ 2 ] + paddings[ 2 ] + paddings[ 3 ] ;
593- out_dims[ 3 ] = x_shape[ 3 ] + paddings[ 0 ] + paddings[ 1 ] ;
594- out_dims[ 4 ] = x_shape[ 4 ] ;
592+ out_dims. at ( 1 ) = x_shape. at ( 1 ) + paddings. at ( 4 ) + paddings. at ( 5 ) ;
593+ out_dims. at ( 2 ) = x_shape. at ( 2 ) + paddings. at ( 2 ) + paddings. at ( 3 ) ;
594+ out_dims. at ( 3 ) = x_shape. at ( 3 ) + paddings. at ( 0 ) + paddings. at ( 1 ) ;
595+ out_dims. at ( 4 ) = x_shape. at ( 4 ) ;
595596 }
596597 return out_dims;
597598 }();
@@ -652,9 +653,9 @@ bool RepeatInterleaveOpInferSymbolicShape(
652653 std::vector<symbol::DimExpr> out_sym_shape;
653654 for (int i = 0 ; i < x_rank; i++) {
654655 if (i == axis) {
655- out_sym_shape.push_back (in_dims_sym[i] * repeats);
656+ out_sym_shape.push_back (in_dims_sym. at (i) * repeats);
656657 } else {
657- out_sym_shape.push_back (in_dims_sym[i] );
658+ out_sym_shape.push_back (in_dims_sym. at (i) );
658659 }
659660 }
660661 return out_sym_shape;
@@ -692,6 +693,13 @@ bool ReshapeOpInferSymbolicShape(
692693 return true ;
693694 };
694695
696+ const auto &IsPositiveInteger = [&](const symbol::DimExpr &dim_expr) {
697+ if (dim_expr.isa <int64_t >()) {
698+ return dim_expr.dyn_cast <int64_t >() > static_cast <int64_t >(0 );
699+ }
700+ return true ;
701+ };
702+
695703 const auto &IsZero = [&](const symbol::DimExpr &dim_expr) {
696704 if (dim_expr.isa <int64_t >()) {
697705 return dim_expr.dyn_cast <int64_t >() == static_cast <int64_t >(0 );
@@ -702,26 +710,33 @@ bool ReshapeOpInferSymbolicShape(
702710 const std::vector<symbol::DimExpr> out_dims = [&] {
703711 const auto &original_shape =
704712 infer_context->GetShapeOrDataForValue (op->operand_source (0 )).shape ();
713+ ExprVec target_shape;
714+ if (shape_dim_expr.data ().has_value ()) {
715+ target_shape = shape_dim_expr.data ().value ();
716+ }
705717
718+ // replace '0' with original shape
719+ for (size_t i = 0 ; i < target_shape.size (); i++) {
720+ if (i < original_shape.size () && IsZero (target_shape.at (i))) {
721+ target_shape.at (i) = original_shape.at (i);
722+ }
723+ }
724+
725+ // replace '-1' with infered shape
706726 const auto &numel =
707727 GetProduct (original_shape, [](const auto &) { return true ; });
708-
709- ExprVec target_shape = details::GetExprVecFromData (shape_dim_expr);
710728 const auto &product_exclude_minus_one =
711- GetProduct (target_shape, IsNotMinusOne);
712-
729+ GetProduct (target_shape, IsPositiveInteger);
713730 const auto &input_dims = target_shape;
714731
715732 std::vector<symbol::DimExpr> out_dims;
716733 out_dims.reserve (input_dims.size ());
717734 for (size_t i = 0 ; i < input_dims.size (); ++i) {
718- auto out_dim_expr = IsNotMinusOne (input_dims[i] )
719- ? input_dims[i]
735+ auto out_dim_expr = IsNotMinusOne (input_dims. at (i) )
736+ ? input_dims. at (i)
720737 : (numel / product_exclude_minus_one);
721- out_dim_expr = IsZero (input_dims[i]) ? original_shape[i] : out_dim_expr;
722738 out_dims.emplace_back (out_dim_expr);
723739 }
724-
725740 return out_dims;
726741 }();
727742
@@ -868,7 +883,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
868883 const bool &all_sections_sym_not_minus_one =
869884 All (sections_sym, IsNotMinusOne);
870885 if (all_sections_sym_not_minus_one) {
871- infer_context->AddEqualCstr (x_dims_sym[ axis] , sum_exclude_minus_one);
886+ infer_context->AddEqualCstr (x_dims_sym. at ( axis) , sum_exclude_minus_one);
872887 }
873888
874889 symbol::TensorListShapeOrDataDimExprs shape_data_list;
@@ -881,10 +896,11 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
881896 return shape_data_list;
882897 }
883898 for (uint32_t idx = 0 ; idx < sections_sym.size (); idx++) {
884- const auto §ion_sym = sections_sym[idx];
885- output_dims_sym[axis] = IsNotMinusOne (section_sym)
886- ? section_sym
887- : x_dims_sym[axis] - sum_exclude_minus_one;
899+ const auto §ion_sym = sections_sym.at (idx);
900+ output_dims_sym.at (axis) =
901+ IsNotMinusOne (section_sym)
902+ ? section_sym
903+ : x_dims_sym.at (axis) - sum_exclude_minus_one;
888904
889905 shape_data_list.push_back (
890906 symbol::TensorShapeOrDataDimExprs (output_dims_sym));
@@ -1052,7 +1068,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
10521068 }
10531069
10541070 for (size_t i = 0 ; i < repeat_times_dimexpr.size (); ++i) {
1055- out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i] ;
1071+ out_shape. at (i) = x_dimexpr. at (i) * repeat_times_dimexpr. at (i) ;
10561072 }
10571073
10581074 symbol::ShapeOrDataDimExprs shape_data{
@@ -1084,7 +1100,7 @@ bool TopkOpInferSymbolicShape(pir::Operation *op,
10841100
10851101 int x_rank = in_dims_sym.size ();
10861102
1087- int k = k_shape_or_data.data ().value ()[ 0 ] .Get <int64_t >();
1103+ int k = k_shape_or_data.data ().value (). at ( 0 ) .Get <int64_t >();
10881104
10891105 if (axis < 0 ) axis += x_rank;
10901106 const auto &out_sym_shape = [&] {
@@ -1093,7 +1109,7 @@ bool TopkOpInferSymbolicShape(pir::Operation *op,
10931109 if (i == axis) {
10941110 out_sym_shape.push_back (symbol::DimExpr (k));
10951111 } else {
1096- out_sym_shape.push_back (in_dims_sym[i] );
1112+ out_sym_shape.push_back (in_dims_sym. at (i) );
10971113 }
10981114 }
10991115 return out_sym_shape;
@@ -1161,7 +1177,7 @@ bool TransposeOpInferSymbolicShape(
11611177
11621178 std::vector<symbol::DimExpr> out_dims (x_dims);
11631179 for (int i = 0 ; i < axis_size; ++i) {
1164- out_dims[i] = x_dims[ formatted_axis[i]] ;
1180+ out_dims. at (i) = x_dims. at ( formatted_axis. at (i)) ;
11651181 }
11661182
11671183 infer_context->SetShapeOrDataForValue (op->result (0 ),
@@ -1225,26 +1241,27 @@ bool SqueezeOpInferSymbolicShape(
12251241 for (size_t i = 0 ; i < in_dims_sym.size (); ++i) {
12261242 // TODO(lanxianghit): if symbol here, maybe we need the result of dim expr
12271243 // simplification
1228- if (in_dims_sym[i] == 1 ) {
1229- should_squeeze[i] = true ;
1244+ if (in_dims_sym. at (i) == 1 ) {
1245+ should_squeeze. at (i) = true ;
12301246 }
12311247 }
12321248 } else {
12331249 for (size_t i = 0 ; i < num_squeeze_dims; ++i) {
12341250 if (in_dims_sym.size () == 0 ) {
12351251 continue ;
12361252 }
1237- int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims_sym.size ()
1238- : squeeze_dims[i];
1253+ int current = squeeze_dims.at (i) < 0
1254+ ? squeeze_dims.at (i) + in_dims_sym.size ()
1255+ : squeeze_dims.at (i);
12391256
1240- if (!should_squeeze[ current] ) {
1257+ if (!should_squeeze. at ( current) ) {
12411258 // At compile time, dim of SYMBOL is allowed to squeeze?
1242- if (in_dims_sym[ current] == 1 ) {
1243- should_squeeze[ current] = true ;
1244- } else if (!in_dims_sym[ current] .Has <std::int64_t >()) {
1245- should_squeeze[ current] = true ;
1259+ if (in_dims_sym. at ( current) == 1 ) {
1260+ should_squeeze. at ( current) = true ;
1261+ } else if (!in_dims_sym. at ( current) .Has <std::int64_t >()) {
1262+ should_squeeze. at ( current) = true ;
12461263 } else {
1247- should_squeeze[ current] = true ;
1264+ should_squeeze. at ( current) = true ;
12481265 }
12491266 }
12501267 }
@@ -1253,8 +1270,8 @@ bool SqueezeOpInferSymbolicShape(
12531270 // Make output dimensions
12541271 std::vector<symbol::DimExpr> output_shape_sym;
12551272 for (size_t i = 0 ; i < in_dims_sym.size (); ++i) {
1256- if (!should_squeeze[i] ) {
1257- output_shape_sym.emplace_back (in_dims_sym[i] );
1273+ if (!should_squeeze. at (i) ) {
1274+ output_shape_sym.emplace_back (in_dims_sym. at (i) );
12581275 }
12591276 }
12601277
@@ -1349,9 +1366,9 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op,
13491366 return counts_dims;
13501367 }
13511368 std::vector<symbol::DimExpr> out_dims = x_dims_sym;
1352- int axis = axes[ 0 ] ;
1369+ int axis = axes. at ( 0 ) ;
13531370 axis = axis >= 0 ? axis : axis + rank;
1354- out_dims[ axis] = unique_dim_sym;
1371+ out_dims. at ( axis) = unique_dim_sym;
13551372 return out_dims;
13561373 }();
13571374
@@ -1365,9 +1382,9 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op,
13651382 }
13661383 inverse_dims.push_back (product);
13671384 } else {
1368- int axis = axes[ 0 ] ;
1385+ int axis = axes. at ( 0 ) ;
13691386 axis = axis >= 0 ? axis : axis + rank;
1370- inverse_dims.push_back (x_dims_sym[ axis] );
1387+ inverse_dims.push_back (x_dims_sym. at ( axis) );
13711388 }
13721389 return inverse_dims;
13731390 }();
@@ -1421,9 +1438,9 @@ bool UniqueConsecutiveOpInferSymbolicShape(
14211438 return counts_dims;
14221439 }
14231440 std::vector<symbol::DimExpr> out_dims = x_dims_sym;
1424- int axis = axes[ 0 ] ;
1441+ int axis = axes. at ( 0 ) ;
14251442 axis = axis >= 0 ? axis : axis + rank;
1426- out_dims[ axis] = unique_dim_sym;
1443+ out_dims. at ( axis) = unique_dim_sym;
14271444 return out_dims;
14281445 }();
14291446
@@ -1437,9 +1454,9 @@ bool UniqueConsecutiveOpInferSymbolicShape(
14371454 }
14381455 inverse_dims.push_back (product);
14391456 } else {
1440- int axis = axes[ 0 ] ;
1457+ int axis = axes. at ( 0 ) ;
14411458 axis = axis >= 0 ? axis : axis + rank;
1442- inverse_dims.push_back (x_dims_sym[ axis] );
1459+ inverse_dims.push_back (x_dims_sym. at ( axis) );
14431460 }
14441461 return inverse_dims;
14451462 }();
@@ -1509,21 +1526,21 @@ bool UnsqueezeOpInferSymbolicShape(
15091526
15101527 // Move old axis, and insert new axis
15111528 for (int i = cur_output_rank; i >= cur; --i) {
1512- if (result_sym_dims[i] == 1 ) {
1529+ if (result_sym_dims. at (i) == 1 ) {
15131530 // Move axis
1514- result_sym_dims[ i + 1 ] = 1 ;
1515- result_sym_dims[i] = 0 ;
1531+ result_sym_dims. at ( i + 1 ) = 1 ;
1532+ result_sym_dims. at (i) = 0 ;
15161533 }
15171534 }
1518- result_sym_dims[ cur] = 1 ;
1535+ result_sym_dims. at ( cur) = 1 ;
15191536 // Add the output size.
15201537 cur_output_rank++;
15211538 }
15221539
15231540 // Make output shape
15241541 for (int in_idx = 0 , out_idx = 0 ; out_idx < output_rank; ++out_idx) {
1525- if (result_sym_dims[ out_idx] == 0 ) {
1526- result_sym_dims[ out_idx] = x_sym_shape[ in_idx++] ;
1542+ if (result_sym_dims. at ( out_idx) == 0 ) {
1543+ result_sym_dims. at ( out_idx) = x_sym_shape. at ( in_idx++) ;
15271544 }
15281545 }
15291546
0 commit comments