@@ -139,7 +139,6 @@ inline ExprVec GetSliceDims(const ExprVec &in_dims,
139139 for (size_t i = 0 ; i < axes.size (); ++i) {
140140 auto out_dim = ends[i] - starts[i];
141141 int64_t axis = axes[i];
142-
143142 // If in_dims[axis] or ends[i] have symbol, nedd get Min(in_dims[axis] -
144143 // start[i], ends[i] - start[i] )
145144 if (!out_dim.isa <int64_t >() &&
@@ -291,219 +290,4 @@ inline ShapeOrData SliceRawInferSymbolicShape(
291290
292291 return out_shape;
293292}
294-
295- inline ExprVec GetStridesSliceDims (
296- const ExprVec &in_dims,
297- const std::vector<int64_t > &axes,
298- const ExprVec &starts_base,
299- const ExprVec &ends_base,
300- const ExprVec &strides_base,
301- std::vector<int64_t > *infer_flags = nullptr ) {
302- ExprVec starts = starts_base;
303- ExprVec ends = ends_base;
304- ExprVec strides = strides_base;
305- auto IsMaxInt = [](const symbol::DimExpr &expr) {
306- return expr.isa <int64_t >() &&
307- expr.Get <int64_t >() ==
308- static_cast <int64_t >(std::numeric_limits<int >::max ());
309- };
310-
311- for (size_t i = 0 ; i < axes.size (); ++i) {
312- int64_t axis = axes.at (i);
313- int64_t start_i = 0 ;
314-
315- if (starts.at (i).isa <int64_t >()) {
316- if (in_dims.at (axis).isa <int64_t >()) {
317- starts.at (i) =
318- (starts.at (i).Get <int64_t >() > in_dims.at (axis).Get <int64_t >())
319- ? in_dims.at (axis)
320- : starts.at (i);
321- starts.at (i) =
322- (starts.at (i).Get <int64_t >() < -in_dims.at (axis).Get <int64_t >())
323- ? symbol::DimExpr ({-1 }) * in_dims.at (axis)
324- : starts.at (i);
325- }
326- start_i = starts.at (i).Get <int64_t >();
327- }
328-
329- int64_t end_i = 0 ;
330- if (ends.at (i).isa <int64_t >()) {
331- if (in_dims.at (axis).isa <int64_t >()) {
332- ends[i] = std::min (ends.at (i).Get <int64_t >(),
333- in_dims.at (axis).Get <int64_t >());
334- }
335- if (ends.at (i).Get <int64_t >() < 0 ) {
336- ends[i] = ends.at (i) + in_dims.at (axis);
337- }
338- if (ends.at (i).isa <int64_t >()) {
339- end_i = ends.at (i).Get <int64_t >();
340- }
341- }
342-
343- ends.at (i) = IsMaxInt (ends.at (i)) ? in_dims.at (axis) : ends.at (i);
344- bool both_negative_or_positive =
345- (start_i >= 0 && end_i >= 0 ) || (start_i <= 0 && end_i <= 0 );
346- bool start_negative_end_positive = start_i <= 0 && end_i >= 0 ;
347- bool start_positive_end_negative = start_i >= 0 && end_i <= 0 ;
348-
349- if (both_negative_or_positive) {
350- continue ;
351- } else if (start_negative_end_positive) {
352- starts.at (i) = starts.at (i) + in_dims.at (axis);
353- } else if (start_positive_end_negative) {
354- starts.at (i) = starts.at (i) - in_dims.at (axis);
355- } else {
356- PADDLE_THROW (common::errors::Fatal (" Dead code" ));
357- }
358- }
359-
360- ExprVec slice_dims (in_dims);
361- PADDLE_ENFORCE_EQ (
362- (axes.size () == starts.size () && axes.size () == ends.size () &&
363- axes.size () == strides.size ()),
364- true ,
365- common::errors::InvalidArgument (
366- " The size of axes must equal size of starts, ends, and strides." ));
367-
368- for (size_t i = 0 ; i < axes.size (); ++i) {
369- auto out_dim = symbol::DimExpr ({-1 }) * ((starts[i] - ends[i]) / strides[i]);
370- int64_t axis = axes[i];
371-
372- if (!out_dim.isa <int64_t >() &&
373- (!in_dims[axis].isa <int64_t >() || !ends[i].isa <int64_t >())) {
374- symbol::List<symbol::DimExpr> min_lists{
375- symbol::DimExpr ({-1 }) * ((starts[i] - in_dims[axis]) / strides[i]),
376- out_dim};
377-
378- slice_dims[axis] =
379- symbol::DimExpr ({symbol::Min<symbol::DimExpr>({min_lists})});
380- } else {
381- slice_dims[axis] = out_dim;
382- }
383- }
384-
385- return slice_dims;
386- }
387-
388- inline ShapeOrData StridedSliceRawInferSymbolicShape (
389- const pir::Value x,
390- const pir::Value out,
391- const ExprVec &starts_expr,
392- const ExprVec &ends_expr,
393- const ExprVec &strides_expr,
394- const std::vector<int64_t > &axes_raw,
395- const std::vector<int64_t > &infer_flags_raw,
396- const std::vector<int64_t > &decrease_axis,
397- pir::InferSymbolicShapeContext *infer_context) {
398- const auto &in_shapeordata = infer_context->GetShapeOrDataForValue (x);
399- ExprVec starts = starts_expr;
400- ExprVec ends = ends_expr;
401- ExprVec strides = strides_expr;
402- std::vector<int64_t > infer_flags = [&infer_flags_raw, &axes_raw] {
403- return infer_flags_raw.empty () ? std::vector<int64_t >(axes_raw.size (), 1 )
404- : infer_flags_raw;
405- }();
406-
407- const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
408- const ExprVec &in_dims = in_shapeordata.shape ();
409- std::vector<int64_t > axes = FormatSliceAxes (axes_raw, in_dims.size ());
410- ExprVec slice_dims =
411- GetStridesSliceDims (in_dims, axes, starts, ends, strides, &infer_flags);
412- ExprVec out_dims = GetDecreasedDims (slice_dims, decrease_axis);
413-
414- auto IsOne = [](const symbol::DimExpr &expr) {
415- return expr.isa <int64_t >() && expr.dyn_cast <int64_t >() == 1 ;
416- };
417- auto IsIntType = [](pir::Value value) {
418- const auto &dtype = value.type ().dyn_cast <pir::DenseTensorType>().dtype ();
419- return dtype.isa <pir::Int32Type>() || dtype.isa <pir::Int64Type>();
420- };
421- if (IsIntType (x) &&
422- (out_dims.empty () || (out_dims.size () == 1 && IsOne (out_dims[0 ])))) {
423- return symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs (
424- out_dims,
425- std::vector<symbol::DimExpr>{infer_context->GetNextSymName ()})};
426- }
427-
428- return symbol::ShapeOrDataDimExprs{
429- symbol::TensorShapeOrDataDimExprs (out_dims)};
430- };
431-
432- // When `pd.slice` is operating on a tensor which is produced by a `pd.shape`
433- // op, the result should be written into data.
434- const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
435- std::vector<symbol::DimExpr> out_data;
436-
437- // Currently, we DO NOT support the case that any element in `axes` `starts`
438- // or `ends` is a Symbol.
439- auto vec_int64 = details::VecExpr2Int64 (starts);
440- PADDLE_ENFORCE_EQ (
441- vec_int64.has_value (),
442- true ,
443- common::errors::InvalidArgument (
444- " for slice op, all the elements in `starts` must be int64_t" ));
445- std::vector<int64_t > starts_int = vec_int64.value ();
446-
447- vec_int64 = details::VecExpr2Int64 (ends);
448- PADDLE_ENFORCE_EQ (
449- vec_int64.has_value (),
450- true ,
451- common::errors::InvalidArgument (
452- " for slice op, all the elements in `ends` must be int64_t" ));
453- std::vector<int64_t > ends_int = vec_int64.value ();
454-
455- vec_int64 = details::VecExpr2Int64 (strides);
456- PADDLE_ENFORCE_EQ (
457- vec_int64.has_value (),
458- true ,
459- common::errors::InvalidArgument (
460- " for slice op, all the elements in `strides` must be int64_t" ));
461-
462- const int64_t start =
463- starts_int[0 ] < 0 ? starts_int[0 ] + in_shapeordata.data ().value ().size ()
464- : starts_int[0 ];
465- const int64_t end = [&]() -> int64_t {
466- if (ends_int[0 ] < 0 ) {
467- return ends_int[0 ] + in_shapeordata.data ().value ().size ();
468- }
469- if (ends_int[0 ] ==
470- static_cast <int64_t >(std::numeric_limits<int >::max ())) {
471- return in_shapeordata.data ().value ().size ();
472- }
473- return ends_int[0 ];
474- }();
475-
476- const int64_t stride = [&]() -> int64_t {
477- if (strides[0 ].isa <int64_t >()) {
478- return strides[0 ].Get <int64_t >();
479- }
480- return 1 ;
481- }();
482-
483- for (int64_t i = start; i < end; i += stride) {
484- out_data.push_back (in_shapeordata.data ().value ().at (i));
485- }
486-
487- const ExprVec shape = GetDecreasedDims (
488- ExprVec{static_cast <int64_t >(out_data.size ())}, decrease_axis);
489- return symbol::ShapeOrDataDimExprs{
490- symbol::TensorShapeOrDataDimExprs (shape, out_data)};
491- };
492-
493- const auto &out_shape = in_shapeordata.data ().has_value ()
494- ? GetDataDimExprs ()
495- : GetShapeDimExprs ();
496- if (out_shape.data ().has_value () && out_shape.shape ().empty ()) { // 0D tensor
497- const paddle::dialect::DenseTensorType &tensor_type =
498- out.type ().dyn_cast <paddle::dialect::DenseTensorType>();
499- const auto &out_ddim = tensor_type.dims ();
500- if (out_ddim.size () == 1 && out_ddim[0 ] == 1 ) { // value is 1D
501- return symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs (
502- std::vector<symbol::DimExpr>{1 }, out_shape.data ().value ())};
503- }
504- }
505-
506- return out_shape;
507- }
508-
509293} // namespace paddle::dialect::slice_utils
0 commit comments