Skip to content

Commit b5b2892

Browse files
committed
Return a failure instead of crashing if shape inference can not be run because of unraked operand types
Signed-off-by: Jonas Rickert <[email protected]>
1 parent 45f07d5 commit b5b2892

22 files changed

+82
-8
lines changed

src/Dialect/ONNX/ONNXOps/Math/DFT.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ LogicalResult ONNXGenericDFTOpShapeHelper<OP_TYPE>::customComputeShape(
3333
// Get info about input data operand.
3434
Value input = operandAdaptor.getInput();
3535
// Get the rank to compensate for N dimensions.
36+
if (!hasShapeAndRank(input)) {
37+
return failure();
38+
}
3639
int64_t rank = createIE->getShapedTypeRank(input);
3740

3841
// Check if the dimension for axis is a literal and in range.

src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ LogicalResult ONNXGenericMatMulOpShapeHelper<OP_TYPE>::computeShape() {
5555
std::tie(A, B) = matMulInputs(operandAdaptor);
5656

5757
// Size all the arrays to padded length.
58+
if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) {
59+
return failure();
60+
}
5861
uint64_t aRank = createIE->getShapedTypeRank(A);
5962
uint64_t bRank = createIE->getShapedTypeRank(B);
6063
int paddedRank = std::max(aRank, bRank);

src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ LogicalResult ONNXGenericReductionOpShapeHelper<OP_TYPE>::customComputeShape(
2929
DimsExpr &axes, int noopWithEmptyAxes) {
3030
typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary());
3131
Value data = operandAdaptor.getData();
32+
if (!hasShapeAndRank(data)) {
33+
return failure();
34+
}
3235
int64_t rank = createIE->getShapedTypeRank(data);
3336
// Normalize the axes: at present, we only support compile time axes, but
3437
// with keep_dim on, it might not be too difficult to generate the code.
@@ -104,7 +107,11 @@ LogicalResult ONNXGenericReductionOpShapeHelper<OP_TYPE>::computeShape() {
104107
createIE->getIntFromArrayAsSymbols(operandAdaptor.getAxes(), axes);
105108
} else {
106109
// When the axis is dynamic, try to infer the rank of output tensor
107-
int64_t dataRank = createIE->getShapedTypeRank(operandAdaptor.getData());
110+
const auto data = operandAdaptor.getData();
111+
if (!hasShapeAndRank(data)) {
112+
return failure();
113+
}
114+
int64_t dataRank = createIE->getShapedTypeRank(data);
108115
int64_t axlesSize = createIE->getArraySize(operandAdaptor.getAxes());
109116
if (!operandAdaptor.getKeepdims() && axlesSize < 0 /*undef shape*/) {
110117
// Even though we did not compute the shape in ShapeHelper, return

src/Dialect/ONNX/ONNXOps/Math/TopK.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ LogicalResult ONNXTopKOpShapeHelper::computeShape() {
3131
// Get info about X and K operands.
3232
Value X = operandAdaptor.getX();
3333
Value K = operandAdaptor.getK();
34+
if (!hasShapeAndRank(X)) {
35+
return failure();
36+
}
3437
int64_t rank = createIE->getShapedTypeRank(X);
3538

3639
// Axis to compute TopK.

src/Dialect/ONNX/ONNXOps/NN/Conv.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() {
374374
Value wValue = operandAdaptor.getW();
375375

376376
// Basic information.
377+
if (!hasShapeAndRank(xValue)) {
378+
return failure();
379+
}
377380
int64_t rank = createIE->getShapedTypeRank(xValue);
378381
int64_t spatialOffset = 2;
379382
int64_t spatialRank = rank - spatialOffset;

src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ LogicalResult ONNXGenericPoolOpShapeHelper<OP_TYPE>::customComputeShape(
3131
std::optional<ArrayAttr> strideOpt, std::optional<ArrayAttr> dilationOpt,
3232
bool hasFilter, bool ceilMode) {
3333
// Basic information.
34+
if(!hasShapeAndRank(xValue)) {
35+
return failure();
36+
}
3437
int64_t rank = createIE->getShapedTypeRank(xValue);
3538
int64_t spatialOffset = 2;
3639
int64_t spatialRank = rank - spatialOffset;

src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,18 @@ LogicalResult ONNXGenericGlobalPoolOpShapeHelper<OP_TYPE>::computeShape() {
4848
template <>
4949
LogicalResult ONNXMaxRoiPoolOpShapeHelper::computeShape() {
5050
ONNXMaxRoiPoolOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
51-
5251
IndexExpr channel = createIE->getShapeAsDim(operandAdaptor.getX(), 1);
53-
uint64_t roisRank = createIE->getShapedTypeRank(operandAdaptor.getRois());
52+
53+
const auto rois = operandAdaptor.getRois();
54+
if (!hasShapeAndRank(rois)) {
55+
return failure();
56+
}
57+
uint64_t roisRank = createIE->getShapedTypeRank(rois);
5458
if (roisRank != 2)
5559
return op->emitError("rois rank is expected to be 2d");
5660

5761
// 2d tensor: (num_rois, 5)
58-
IndexExpr numRois = createIE->getShapeAsDim(operandAdaptor.getRois(), 0);
62+
IndexExpr numRois = createIE->getShapeAsDim(rois, 0);
5963
DimsExpr pooledDims;
6064
createIE->getIntFromArrayAsLiterals(
6165
operandAdaptor.getPooledShape(), pooledDims);

src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ LogicalResult ONNXBroadcastOpShapeHelper::customComputeShape(
285285
DimsExpr dimsExpr;
286286
uint64_t numOfInputs = initialOperands.size();
287287

288+
if (!llvm::all_of(initialOperands,
289+
[](Value initalOperand) { return hasShapeAndRank(initalOperand); })) {
290+
return failure();
291+
}
292+
288293
// Compute rank of the output. Rank of the output is the maximum rank of all
289294
// initial operands.
290295
uint64_t additionalOperRank =

src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ LogicalResult ONNXCompressOpShapeHelper::computeShape() {
3131
ONNXCompressOpAdaptor operandAdaptor(operands);
3232
Value input = operandAdaptor.getInput();
3333
Value cond = operandAdaptor.getCondition();
34+
if (!hasShapeAndRank(input)) {
35+
return failure();
36+
}
3437
int64_t inputRank = createIE->getShapedTypeRank(input);
3538
createIE->assertHasShapeAndRank(cond);
3639
std::optional<int64_t> optionalAxis = compressOp.getAxis();

src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ LogicalResult ONNXDepthToSpaceOpShapeHelper::computeShape() {
3030
ONNXDepthToSpaceOp depthOp = llvm::cast<ONNXDepthToSpaceOp>(op);
3131
ONNXDepthToSpaceOpAdaptor operandAdaptor(operands);
3232
Value input = operandAdaptor.getInput();
33+
if (!hasShapeAndRank(input)) {
34+
return failure();
35+
}
3336
int64_t inputRank = createIE->getShapedTypeRank(input);
3437
assert(inputRank == 4 && "Unexpected input tensor rank");
3538
int64_t blocksize = depthOp.getBlocksize();

0 commit comments

Comments
 (0)