Skip to content

Commit 30039ca

Browse files
authored
【Infer Symbolic Shape No.14】[BUAA] Add check_numerics op (#67735)
* Finished check_numerics op * Fixed errors * Fixed errors
1 parent 80cb18a commit 30039ca

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

100755100644
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,21 @@ bool Cast_OpInferSymbolicShape(pir::Operation *op,
496496
return CastOpInferSymbolicShape(op, infer_context);
497497
}
498498

499+
bool CheckNumericsOpInferSymbolicShape(
500+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
501+
infer_context->SetShapeOrDataForValue(
502+
op->result(0),
503+
symbol::ShapeOrDataDimExprs{
504+
symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(3)})});
505+
506+
infer_context->SetShapeOrDataForValue(
507+
op->result(1),
508+
symbol::ShapeOrDataDimExprs{
509+
symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(3)})});
510+
511+
return true;
512+
}
513+
499514
bool CholeskyOpInferSymbolicShape(
500515
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
501516
const auto &x_shape =

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod_)
4545
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum)
4646
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum_)
4747
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ChannelShuffle)
48+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckNumerics)
4849
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Crop)
4950
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DecodeJpeg)
5051
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diag)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,7 @@
875875
func : CheckNumericsInferMeta
876876
kernel :
877877
func : check_numerics
878+
interfaces : paddle::dialect::InferSymbolicShapeInterface
878879

879880
- op : cholesky
880881
args : (Tensor x, bool upper=false)

0 commit comments

Comments
 (0)