File tree Expand file tree Collapse file tree 3 files changed +26
-8
lines changed
fluid/pir/dialect/operator/interface/infer_symbolic_shape Expand file tree Collapse file tree 3 files changed +26
-8
lines changed Original file line number Diff line number Diff line change @@ -1960,12 +1960,30 @@ bool SequenceMaskOpInferSymbolicShape(
19601960 return true ;
19611961}
19621962
1963- // bool ShuffleBatchOpInferSymbolicShape(pir::Operation *op,
1964- // pir::InferSymbolicShapeContext
1965- // *infer_context) {
1966- // // pass
1967- // return true;
1968- // }
1963+ bool ShuffleBatchOpInferSymbolicShape (
1964+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1965+ const auto &x_shape =
1966+ infer_context->GetShapeOrDataForValue (op->operand_source (0 )).shape ();
1967+ infer_context->SetShapeOrDataForValue (
1968+ op->result (0 ),
1969+ symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs (x_shape)});
1970+ infer_context->SetShapeOrDataForValue (
1971+ op->result (2 ),
1972+ symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs ({1 })});
1973+ const symbol::DimExpr shuffleidx = [&] {
1974+ symbol::DimExpr shuffleidx{1 };
1975+ for (const auto &i : x_shape) {
1976+ shuffleidx = shuffleidx * i;
1977+ }
1978+ return shuffleidx;
1979+ }();
1980+
1981+ infer_context->SetShapeOrDataForValue (
1982+ op->result (1 ),
1983+ symbol::ShapeOrDataDimExprs{
1984+ symbol::TensorShapeOrDataDimExprs ({shuffleidx})});
1985+ return true ;
1986+ }
19691987
19701988bool StftOpInferSymbolicShape (pir::Operation *op,
19711989 pir::InferSymbolicShapeContext *infer_context) {
Original file line number Diff line number Diff line change @@ -88,7 +88,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs)
8888OP_DECLARE_INFER_SYMBOLIC_SHAPE (Searchsorted)
8989OP_DECLARE_INFER_SYMBOLIC_SHAPE (SegmentPool)
9090OP_DECLARE_INFER_SYMBOLIC_SHAPE (SequenceMask)
91- // OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleBatch)
91+ OP_DECLARE_INFER_SYMBOLIC_SHAPE (ShuffleBatch)
9292OP_DECLARE_INFER_SYMBOLIC_SHAPE (Solve)
9393OP_DECLARE_INFER_SYMBOLIC_SHAPE (Stft)
9494OP_DECLARE_INFER_SYMBOLIC_SHAPE (Swiglu)
Original file line number Diff line number Diff line change 45084508 traits : pir::SideEffectTrait
45094509 data_transform :
45104510 skip_transform : seed
4511- # interfaces : paddle::dialect::InferSymbolicShapeInterface
4511+ interfaces : paddle::dialect::InferSymbolicShapeInterface
45124512
45134513- op : shuffle_channel
45144514 args : (Tensor x, int group = 1)
You can’t perform that action at this time.
0 commit comments