Skip to content

Commit e923bb3

Browse files
authored
[Infer Symbolic Shape No.240] Add symbol_infer_interface for uniform_random_batch_size_like (#68980)
* Add symbol_infer_interface for uniform_random_batch_size_like * Fix * Fix * Fix * Fix * revert irrelevant change
1 parent 2f96898 commit e923bb3

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4052,11 +4052,15 @@ bool UniformInplace_OpInferSymbolicShape(
40524052
return UniformInplaceOpInferSymbolicShape(op, infer_context);
40534053
}
40544054

4055-
// bool UniformRandomBatchSizeLikeOpInferSymbolicShape(
4056-
// pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
4057-
// // pass
4058-
// return true;
4059-
// }
4055+
bool UniformRandomBatchSizeLikeOpInferSymbolicShape(
4056+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
4057+
return BatchSizeLikeInferSymbolicShape(op, infer_context);
4058+
}
4059+
4060+
bool UniformRandomBatchSizeLikeSrOpInferSymbolicShape(
4061+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
4062+
return UniformRandomBatchSizeLikeOpInferSymbolicShape(op, infer_context);
4063+
}
40604064

40614065
bool UniqueOpInferSymbolicShape(pir::Operation *op,
40624066
pir::InferSymbolicShapeContext *infer_context) {
@@ -4224,12 +4228,8 @@ bool UnsqueezeOpInferSymbolicShape(
42244228
int x_dims_size = x_sym_shape.size();
42254229

42264230
std::vector<symbol::DimExpr> axis_sym;
4227-
if (axis_shape_or_data.data().has_value()) {
4228-
axis_sym = axis_shape_or_data.data().value();
4229-
} else {
4230-
axis_sym =
4231-
details::GetOrCreateExprVecFromData(axis_shape_or_data, infer_context);
4232-
}
4231+
axis_sym =
4232+
details::GetOrCreateExprVecFromData(axis_shape_or_data, infer_context);
42334233
int axis_sym_size = axis_sym.size();
42344234

42354235
// GetUnsqueezeShape

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(TransLayout)
156156
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unbind)
157157
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformInplace)
158158
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformInplace_)
159-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformRandomBatchSizeLike)
159+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformRandomBatchSizeLike)
160+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformRandomBatchSizeLikeSr)
160161
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unique)
161162
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniqueConsecutive)
162163
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unsqueeze)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5190,7 +5190,7 @@
51905190
data_type: dtype
51915191
no_need_buffer: input
51925192
traits : pir::SideEffectTrait, paddle::dialect::ForwardOnlyTrait
5193-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
5193+
interfaces : paddle::dialect::InferSymbolicShapeInterface
51945194

51955195
- op : unique_consecutive
51965196
args : (Tensor x, bool return_inverse = false, bool return_counts = false, int[] axis = {}, DataType dtype = DataType::FLOAT32)

0 commit comments

Comments
 (0)