Skip to content

Commit 2f96898

Browse files
authored
[Infer Symbolic Shape No.225] Add symbol_infer_interface for shuffle_batch (#68978)
* Add symbol_infer_interface for shuffle_batch * Fix * Refine logic * apply review * revert irrelevant change
1 parent aa857e5 commit 2f96898

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,12 +1960,30 @@ bool SequenceMaskOpInferSymbolicShape(
19601960
return true;
19611961
}
19621962

1963-
// bool ShuffleBatchOpInferSymbolicShape(pir::Operation *op,
1964-
// pir::InferSymbolicShapeContext
1965-
// *infer_context) {
1966-
// // pass
1967-
// return true;
1968-
// }
1963+
bool ShuffleBatchOpInferSymbolicShape(
1964+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1965+
const auto &x_shape =
1966+
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
1967+
infer_context->SetShapeOrDataForValue(
1968+
op->result(0),
1969+
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});
1970+
infer_context->SetShapeOrDataForValue(
1971+
op->result(2),
1972+
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs({1})});
1973+
const symbol::DimExpr shuffleidx = [&] {
1974+
symbol::DimExpr shuffleidx{1};
1975+
for (const auto &i : x_shape) {
1976+
shuffleidx = shuffleidx * i;
1977+
}
1978+
return shuffleidx;
1979+
}();
1980+
1981+
infer_context->SetShapeOrDataForValue(
1982+
op->result(1),
1983+
symbol::ShapeOrDataDimExprs{
1984+
symbol::TensorShapeOrDataDimExprs({shuffleidx})});
1985+
return true;
1986+
}
19691987

19701988
bool StftOpInferSymbolicShape(pir::Operation *op,
19711989
pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs)
8888
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted)
8989
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SegmentPool)
9090
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SequenceMask)
91-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleBatch)
91+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleBatch)
9292
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Solve)
9393
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stft)
9494
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Swiglu)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4508,7 +4508,7 @@
45084508
traits : pir::SideEffectTrait
45094509
data_transform :
45104510
skip_transform : seed
4511-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
4511+
interfaces : paddle::dialect::InferSymbolicShapeInterface
45124512

45134513
- op : shuffle_channel
45144514
args : (Tensor x, int group = 1)

0 commit comments

Comments
 (0)