Transform SequenceAt to split for special cases #3018
Conversation
Signed-off-by: chentong319 <[email protected]>
LGTM.
ONNXSplitToSequenceOp splitToSequence;
if (!(splitToSequence = mlir::dyn_cast<ONNXSplitToSequenceOp>(
          inputSequence.getDefiningOp())))
  return false;
This can be shortened to:
ONNXSplitToSequenceOp splitToSequence = inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
if (!splitToSequence)
return false;
Done
Value replaceSequenceAt(
    PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
  ONNXSequenceAtOp op =
      mlir::cast<ONNXSequenceAtOp>(sequenceAtResult.getDefiningOp());
Same here, we can use ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();
Done
src/Dialect/ONNX/DialectBuilder.cpp
Outdated
@@ -444,7 +444,7 @@ TensorType OnnxBuilder::toTensor(Type input) const {
 }

 TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
-  assert(inputs.size() >= 2 && "Expect at least two inputs");
+  //assert(inputs.size() >= 2 && "Expect at least two inputs");
Looks like we can remove this completely.
Done
Signed-off-by: chentong319 <[email protected]>
LGTM, thanks for the quick fix and the lit-test.
It would be interesting to study how this pattern fares compared to the original pattern (it was for an LSTM, correct?).
@chentong319 Dominic is eager to test this. Can you implement the minor suggestions and merge it into main? Thanks
thanks @chentong319 .
Jenkins Linux amd64 Build #16051 [push] Transform SequenceAt to ... started at 12:34
Jenkins Linux s390x Build #16054 [push] Transform SequenceAt to ... started at 13:34
Jenkins Linux ppc64le Build #15081 [push] Transform SequenceAt to ... started at 13:50
Jenkins Linux amd64 Build #16051 [push] Transform SequenceAt to ... failed after 1 min 36 sec
Jenkins Linux s390x Build #16054 [push] Transform SequenceAt to ... passed after 1 hr 41 min
Jenkins Linux ppc64le Build #15081 [push] Transform SequenceAt to ... passed after 2 hr 18 min
Certain PyTorch torch.onnx.export paths break the LSTM op into lower-level operations and generate a SplitToSequence plus SequenceAt operation pattern. For example:
ONNX-MLIR currently has no lowering for SplitToSequence, and sequence-related ops in general are not well optimized in ONNX-MLIR. I found that this code pattern can be converted into tensor operations, avoiding the sequence ops entirely.
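A minimal sketch of the rewrite described above (operand names, shapes, and attribute values are illustrative, not taken from the actual exported model): the SplitToSequence/SequenceAt pair is replaced by a plain onnx.Split, whose n-th result stands in for the sequence element read back at position n.

```mlir
// Before: build a sequence, then read element 0 back out.
%seq = "onnx.SplitToSequence"(%x, %split) {axis = 0 : si64}
    : (tensor<4x8xf32>, tensor<2xi64>) -> !onnx.Seq<tensor<2x8xf32>>
%pos = onnx.Constant dense<0> : tensor<i64>
%elem = "onnx.SequenceAt"(%seq, %pos)
    : (!onnx.Seq<tensor<2x8xf32>>, tensor<i64>) -> tensor<2x8xf32>

// After: a direct onnx.Split; result #0 replaces the SequenceAt result,
// and no sequence value is materialized at all.
%elem0, %elem1 = "onnx.Split"(%x, %split) {axis = 0 : si64}
    : (tensor<4x8xf32>, tensor<2xi64>) -> (tensor<2x8xf32>, tensor<2x8xf32>)
```

The rewrite is only valid when the SequenceAt position is a compile-time constant, since it must be mapped to a fixed result index of onnx.Split.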
The two onnx.Split ops should be mergeable into one, but current onnx-mlir does not merge them. This needs further investigation.
However, the exported model contains another kind of SequenceAt, in which the output type of the SequenceAt differs from the element type of the input sequence type. I think this is an error in the exporter; nevertheless, I tried to handle this case in the transformation of this PR.
Output from this PR:
Other small fixes in this PR: without them, --EmitONNXBasic will fail for this model.

Test:
With the same random input, the outputs of lstm_no_data and lstm_no_dynamo are the same within atol=0.01 and rtol=0.05, with or without NNPA turned on.
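The tolerance criterion used in the test above can be sketched with NumPy (the arrays here are hypothetical stand-ins; the real test compares the outputs of the two compiled models):

```python
import numpy as np

# Hypothetical model outputs: "expected" stands in for lstm_no_dynamo's
# output, "actual" for lstm_no_data's, differing by small numerical noise.
rng = np.random.default_rng(0)
expected = rng.standard_normal((2, 4, 8)).astype(np.float32)
actual = expected + 0.001 * rng.standard_normal((2, 4, 8)).astype(np.float32)

# Same criterion as the PR's test: |actual - expected| <= atol + rtol * |expected|
match = np.allclose(actual, expected, rtol=0.05, atol=0.01)
print(match)
```

np.allclose applies the combined absolute/relative bound elementwise, so small outputs are judged mainly by atol and large outputs mainly by rtol.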