Commit c6b2115

Revert "【Infer Symbolic Shape No.232】Add infer_symbol_shape for StridedSlice …" (#70061)
This reverts commit d72d6ad.
1 parent 67f3613 commit c6b2115

7 files changed: 18 additions & 287 deletions

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h

Lines changed: 0 additions & 216 deletions

@@ -139,7 +139,6 @@ inline ExprVec GetSliceDims(const ExprVec &in_dims,
   for (size_t i = 0; i < axes.size(); ++i) {
     auto out_dim = ends[i] - starts[i];
     int64_t axis = axes[i];
-
     // If in_dims[axis] or ends[i] have symbol, nedd get Min(in_dims[axis] -
     // start[i], ends[i] - start[i] )
     if (!out_dim.isa<int64_t>() &&
@@ -291,219 +290,4 @@ inline ShapeOrData SliceRawInferSymbolicShape(
 
   return out_shape;
 }
-
-inline ExprVec GetStridesSliceDims(
-    const ExprVec &in_dims,
-    const std::vector<int64_t> &axes,
-    const ExprVec &starts_base,
-    const ExprVec &ends_base,
-    const ExprVec &strides_base,
-    std::vector<int64_t> *infer_flags = nullptr) {
-  ExprVec starts = starts_base;
-  ExprVec ends = ends_base;
-  ExprVec strides = strides_base;
-  auto IsMaxInt = [](const symbol::DimExpr &expr) {
-    return expr.isa<int64_t>() &&
-           expr.Get<int64_t>() ==
-               static_cast<int64_t>(std::numeric_limits<int>::max());
-  };
-
-  for (size_t i = 0; i < axes.size(); ++i) {
-    int64_t axis = axes.at(i);
-    int64_t start_i = 0;
-
-    if (starts.at(i).isa<int64_t>()) {
-      if (in_dims.at(axis).isa<int64_t>()) {
-        starts.at(i) =
-            (starts.at(i).Get<int64_t>() > in_dims.at(axis).Get<int64_t>())
-                ? in_dims.at(axis)
-                : starts.at(i);
-        starts.at(i) =
-            (starts.at(i).Get<int64_t>() < -in_dims.at(axis).Get<int64_t>())
-                ? symbol::DimExpr({-1}) * in_dims.at(axis)
-                : starts.at(i);
-      }
-      start_i = starts.at(i).Get<int64_t>();
-    }
-
-    int64_t end_i = 0;
-    if (ends.at(i).isa<int64_t>()) {
-      if (in_dims.at(axis).isa<int64_t>()) {
-        ends[i] = std::min(ends.at(i).Get<int64_t>(),
-                           in_dims.at(axis).Get<int64_t>());
-      }
-      if (ends.at(i).Get<int64_t>() < 0) {
-        ends[i] = ends.at(i) + in_dims.at(axis);
-      }
-      if (ends.at(i).isa<int64_t>()) {
-        end_i = ends.at(i).Get<int64_t>();
-      }
-    }
-
-    ends.at(i) = IsMaxInt(ends.at(i)) ? in_dims.at(axis) : ends.at(i);
-    bool both_negative_or_positive =
-        (start_i >= 0 && end_i >= 0) || (start_i <= 0 && end_i <= 0);
-    bool start_negative_end_positive = start_i <= 0 && end_i >= 0;
-    bool start_positive_end_negative = start_i >= 0 && end_i <= 0;
-
-    if (both_negative_or_positive) {
-      continue;
-    } else if (start_negative_end_positive) {
-      starts.at(i) = starts.at(i) + in_dims.at(axis);
-    } else if (start_positive_end_negative) {
-      starts.at(i) = starts.at(i) - in_dims.at(axis);
-    } else {
-      PADDLE_THROW(common::errors::Fatal("Dead code"));
-    }
-  }
-
-  ExprVec slice_dims(in_dims);
-  PADDLE_ENFORCE_EQ(
-      (axes.size() == starts.size() && axes.size() == ends.size() &&
-       axes.size() == strides.size()),
-      true,
-      common::errors::InvalidArgument(
-          "The size of axes must equal size of starts, ends, and strides."));
-
-  for (size_t i = 0; i < axes.size(); ++i) {
-    auto out_dim = symbol::DimExpr({-1}) * ((starts[i] - ends[i]) / strides[i]);
-    int64_t axis = axes[i];
-
-    if (!out_dim.isa<int64_t>() &&
-        (!in_dims[axis].isa<int64_t>() || !ends[i].isa<int64_t>())) {
-      symbol::List<symbol::DimExpr> min_lists{
-          symbol::DimExpr({-1}) * ((starts[i] - in_dims[axis]) / strides[i]),
-          out_dim};
-
-      slice_dims[axis] =
-          symbol::DimExpr({symbol::Min<symbol::DimExpr>({min_lists})});
-    } else {
-      slice_dims[axis] = out_dim;
-    }
-  }
-
-  return slice_dims;
-}
-
-inline ShapeOrData StridedSliceRawInferSymbolicShape(
-    const pir::Value x,
-    const pir::Value out,
-    const ExprVec &starts_expr,
-    const ExprVec &ends_expr,
-    const ExprVec &strides_expr,
-    const std::vector<int64_t> &axes_raw,
-    const std::vector<int64_t> &infer_flags_raw,
-    const std::vector<int64_t> &decrease_axis,
-    pir::InferSymbolicShapeContext *infer_context) {
-  const auto &in_shapeordata = infer_context->GetShapeOrDataForValue(x);
-  ExprVec starts = starts_expr;
-  ExprVec ends = ends_expr;
-  ExprVec strides = strides_expr;
-  std::vector<int64_t> infer_flags = [&infer_flags_raw, &axes_raw] {
-    return infer_flags_raw.empty() ? std::vector<int64_t>(axes_raw.size(), 1)
-                                   : infer_flags_raw;
-  }();
-
-  const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
-    const ExprVec &in_dims = in_shapeordata.shape();
-    std::vector<int64_t> axes = FormatSliceAxes(axes_raw, in_dims.size());
-    ExprVec slice_dims =
-        GetStridesSliceDims(in_dims, axes, starts, ends, strides, &infer_flags);
-    ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis);
-
-    auto IsOne = [](const symbol::DimExpr &expr) {
-      return expr.isa<int64_t>() && expr.dyn_cast<int64_t>() == 1;
-    };
-    auto IsIntType = [](pir::Value value) {
-      const auto &dtype = value.type().dyn_cast<pir::DenseTensorType>().dtype();
-      return dtype.isa<pir::Int32Type>() || dtype.isa<pir::Int64Type>();
-    };
-    if (IsIntType(x) &&
-        (out_dims.empty() || (out_dims.size() == 1 && IsOne(out_dims[0])))) {
-      return symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
-          out_dims,
-          std::vector<symbol::DimExpr>{infer_context->GetNextSymName()})};
-    }
-
-    return symbol::ShapeOrDataDimExprs{
-        symbol::TensorShapeOrDataDimExprs(out_dims)};
-  };
-
-  // When `pd.slice` is operating on a tensor which is produced by a `pd.shape`
-  // op, the result should be written into data.
-  const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
-    std::vector<symbol::DimExpr> out_data;
-
-    // Currently, we DO NOT support the case that any element in `axes` `starts`
-    // or `ends` is a Symbol.
-    auto vec_int64 = details::VecExpr2Int64(starts);
-    PADDLE_ENFORCE_EQ(
-        vec_int64.has_value(),
-        true,
-        common::errors::InvalidArgument(
-            "for slice op, all the elements in `starts` must be int64_t"));
-    std::vector<int64_t> starts_int = vec_int64.value();
-
-    vec_int64 = details::VecExpr2Int64(ends);
-    PADDLE_ENFORCE_EQ(
-        vec_int64.has_value(),
-        true,
-        common::errors::InvalidArgument(
-            "for slice op, all the elements in `ends` must be int64_t"));
-    std::vector<int64_t> ends_int = vec_int64.value();
-
-    vec_int64 = details::VecExpr2Int64(strides);
-    PADDLE_ENFORCE_EQ(
-        vec_int64.has_value(),
-        true,
-        common::errors::InvalidArgument(
-            "for slice op, all the elements in `strides` must be int64_t"));
-
-    const int64_t start =
-        starts_int[0] < 0 ? starts_int[0] + in_shapeordata.data().value().size()
-                          : starts_int[0];
-    const int64_t end = [&]() -> int64_t {
-      if (ends_int[0] < 0) {
-        return ends_int[0] + in_shapeordata.data().value().size();
-      }
-      if (ends_int[0] ==
-          static_cast<int64_t>(std::numeric_limits<int>::max())) {
-        return in_shapeordata.data().value().size();
-      }
-      return ends_int[0];
-    }();
-
-    const int64_t stride = [&]() -> int64_t {
-      if (strides[0].isa<int64_t>()) {
-        return strides[0].Get<int64_t>();
-      }
-      return 1;
-    }();
-
-    for (int64_t i = start; i < end; i += stride) {
-      out_data.push_back(in_shapeordata.data().value().at(i));
-    }
-
-    const ExprVec shape = GetDecreasedDims(
-        ExprVec{static_cast<int64_t>(out_data.size())}, decrease_axis);
-    return symbol::ShapeOrDataDimExprs{
-        symbol::TensorShapeOrDataDimExprs(shape, out_data)};
-  };
-
-  const auto &out_shape = in_shapeordata.data().has_value()
-                              ? GetDataDimExprs()
-                              : GetShapeDimExprs();
-  if (out_shape.data().has_value() && out_shape.shape().empty()) {  // 0D tensor
-    const paddle::dialect::DenseTensorType &tensor_type =
-        out.type().dyn_cast<paddle::dialect::DenseTensorType>();
-    const auto &out_ddim = tensor_type.dims();
-    if (out_ddim.size() == 1 && out_ddim[0] == 1) {  // value is 1D
-      return symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
-          std::vector<symbol::DimExpr>{1}, out_shape.data().value())};
-    }
-  }
-
-  return out_shape;
-}
-
 }  // namespace paddle::dialect::slice_utils
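For reference, when every start, end, and stride is a concrete integer, the per-axis extent that the reverted GetStridesSliceDims computes reduces to ordinary Python slicing arithmetic. A minimal sketch of that normalize-then-measure rule (plain Python, positive strides only, not part of the commit):

def strided_slice_extent(dim, start, end, stride):
    """Static per-axis extent of dim[start:end:stride] for stride > 0."""
    # Normalize negative indices and clamp against the known dimension size,
    # mirroring the clamping done in the reverted helper.
    start = max(start + dim, 0) if start < 0 else min(start, dim)
    end = max(end + dim, 0) if end < 0 else min(end, dim)
    # len(range(...)) equals ceil((end - start) / stride) when end > start.
    return len(range(start, end, stride))

# A dim of size 10 sliced as [1:9:2] keeps indices 1, 3, 5, 7 -> extent 4.
assert strided_slice_extent(10, 1, 9, 2) == 4
# Negative start: [-3:10:1] on size 10 keeps indices 7, 8, 9 -> extent 3.
assert strided_slice_extent(10, -3, 10, 1) == 3

The symbolic version removed above has to return a DimExpr instead of an int and wrap the result in Min(...) whenever the input dimension or the end is still a symbol.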

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 6 additions & 36 deletions

@@ -3515,42 +3515,12 @@ bool SplitWithNumOpInferSymbolicShape(
   return true;
 }
 
-bool StridedSliceOpInferSymbolicShape(
-    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
-  pir::Value operand_source = op->operand_source(0);
-  pir::Value operand_starts = op->operand_source(1);
-  pir::Value operand_ends = op->operand_source(2);
-  pir::Value operand_strides = op->operand_source(3);
-  pir::Value res = op->result(0);
-
-  const symbol::ShapeOrDataDimExprs &starts_shape_data =
-      infer_context->GetShapeOrDataForValue(operand_starts);
-  const symbol::ShapeOrDataDimExprs &ends_shape_data =
-      infer_context->GetShapeOrDataForValue(operand_ends);
-  const symbol::ShapeOrDataDimExprs &strides_shape_data =
-      infer_context->GetShapeOrDataForValue(operand_strides);
-
-  ExprVec starts = slice_utils::GetExprVecFromData(starts_shape_data);
-  ExprVec ends = slice_utils::GetExprVecFromData(ends_shape_data);
-  ExprVec strides = slice_utils::GetExprVecFromData(strides_shape_data);
-
-  std::vector<int32_t> axes_vec = details::GetVectorAttr<int32_t>(op, "axes");
-  std::vector<int64_t> axes_vec_64(axes_vec.begin(), axes_vec.end());
-
-  infer_context->SetShapeOrDataForValue(
-      res,
-      slice_utils::StridedSliceRawInferSymbolicShape(operand_source,
-                                                     res,
-                                                     starts,
-                                                     ends,
-                                                     strides,
-                                                     axes_vec_64,
-                                                     std::vector<int64_t>{},
-                                                     std::vector<int64_t>{},
-                                                     infer_context));
-
-  return true;
-}
+// bool StridedSliceOpInferSymbolicShape(pir::Operation *op,
+//                                       pir::InferSymbolicShapeContext
+//                                       *infer_context) {
+//   // pass
+//   return true;
+// }
 
 bool SumOpInferSymbolicShape(pir::Operation *op,
                              pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(SplitWithNum)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(SquaredL2Norm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze_)
-OP_DECLARE_INFER_SYMBOLIC_SHAPE(StridedSlice)
+// OP_DECLARE_INFER_SYMBOLIC_SHAPE(StridedSlice)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sum)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion

@@ -4800,7 +4800,7 @@
   kernel :
     func : strided_slice
   backward : strided_slice_grad
-  interfaces : paddle::dialect::InferSymbolicShapeInterface
+  # interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : sum
   args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false)

python/paddle/base/variable_index.py

Lines changed: 1 addition & 10 deletions

@@ -764,16 +764,7 @@ def get_tensor_with_basic_indexing(
         stride = attrs['strides']
         if use_strided_slice:
             # TODO(zoooo0820): support strided_slice_array until PIR API is ready
-            if in_pir_mode():
-                if isinstance(st, (list, tuple)):
-                    if paddle.utils._contain_var(st):
-                        st = paddle.utils.get_int_tensor_list(st)
-                if isinstance(end, (list, tuple)):
-                    if paddle.utils._contain_var(end):
-                        end = paddle.utils.get_int_tensor_list(end)
-                if isinstance(stride, (list, tuple)):
-                    if paddle.utils._contain_var(stride):
-                        stride = paddle.utils.get_int_tensor_list(stride)
+
             out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
             if len(decrease_axes) > 0:
                 out = paddle._C_ops.squeeze(out, decrease_axes)
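The deleted lines only converted Python start/end/stride lists into tensor lists under PIR; the call into paddle._C_ops.strided_slice itself is untouched. As a rough smoke test of the user-facing path this branch serves (not part of the commit, and it assumes a non-unit step is what sets use_strided_slice):

import paddle

x = paddle.arange(12).reshape([3, 4])
# Basic indexing with a non-unit step should exercise the strided-slice
# branch shown in the hunk above.
y = x[:, 0:4:2]      # keeps columns 0 and 2
print(y.shape)       # expected: [3, 2]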

python/paddle/tensor/manipulation.py

Lines changed: 1 addition & 16 deletions

@@ -5632,22 +5632,7 @@ def strided_slice(
            >>> sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
            >>> # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2].
     """
-    if in_dynamic_mode():
-        return _C_ops.strided_slice(x, axes, starts, ends, strides)
-    elif in_pir_mode():
-
-        def _convert_to_tensor_list(input):
-            if isinstance(input, paddle.pir.Value):
-                input.stop_gradient = True
-            elif isinstance(input, (list, tuple)):
-                if paddle.utils._contain_var(input):
-                    input = paddle.utils.get_int_tensor_list(input)
-            return input
-
-        starts = _convert_to_tensor_list(starts)
-        ends = _convert_to_tensor_list(ends)
-        strides = _convert_to_tensor_list(strides)
-
+    if in_dynamic_or_pir_mode():
         return _C_ops.strided_slice(x, axes, starts, ends, strides)
     else:
         helper = LayerHelper('strided_slice', **locals())
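After the revert, paddle.strided_slice forwards directly to _C_ops.strided_slice in both dynamic and PIR mode. A small usage sketch consistent with the docstring fragment visible in the hunk (the input shape [3, 4, 5, 6] is an assumption borrowed from that example):

import paddle

x = paddle.zeros([3, 4, 5, 6])
# Equivalent to x[:, 1:3:1, 0:2:1, 2:4:2] from the docstring example.
out = paddle.strided_slice(
    x, axes=[1, 2, 3], starts=[1, 0, 2], ends=[3, 2, 4], strides=[1, 1, 2]
)
print(out.shape)  # expected: [3, 2, 2, 1]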
