Conversation

@chentong319 (Collaborator) commented Nov 26, 2024

Certain PyTorch torch.onnx.export paths break the LSTM op into lower-level operations, generating a SplitToSequence and SequenceAt pattern. For example:

  %15 = onnx.Constant dense<0> : tensor<i64>
  %38 = onnx.Constant dense<1> : tensor<i64>
  %65 = onnx.Constant dense<100> : tensor<i64>
  %66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
  %67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
  %68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
  %40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>

ONNX-MLIR currently has no lowering for SplitToSequence, and sequence-related ops in general are not well optimized in ONNX-MLIR. I found that this code pattern can be converted into plain tensor operations, avoiding the sequence ops entirely:

    %0 = onnx.Constant dense<100> : tensor<4xi64>
    %1:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %2:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %3 = "onnx.Add"(%1#0, %2#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>

The two onnx.Split ops are identical and should be mergeable into one, but current onnx-mlir does not merge them; this needs further investigation.
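
For reference, a sketch of the merged form one would expect once the duplicate Split is eliminated (e.g. by CSE), with the Add reading two results of a single Split:

    %0 = onnx.Constant dense<100> : tensor<4xi64>
    %1:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %2 = "onnx.Add"(%1#0, %1#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>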

The exported model also contains another kind of SequenceAt, in which the output type of SequenceAt differs from the element type of the input sequence. I think this is an error in the exporter; nevertheless, the transformation in this PR handles it:

    %26 = onnx.Constant dense<0> : tensor<i64>
    %27 = onnx.Constant dense<1> : tensor<i64>
    %32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
    %33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>

Output from this PR:

    %0 = onnx.Constant dense<0> : tensor<1xi64>
    %1 = onnx.Constant dense<1> : tensor<1xi64>
    %2 = "onnx.Split"(%arg0, %1) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32>
    %3 = "onnx.Squeeze"(%2, %0) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32>

Other small fixes in this PR:

  1. Handle the case where onnx.ConstantOp has not been normalized yet; otherwise --EmitONNXBasic fails on this model.
  2. Drop the check that the TypeRange size is larger than one; in this model, Split happens to have a single output.

Test:
With the same random input, the outputs of lstm_no_data and lstm_no_dynamo agree within atol=0.01, rtol=0.05, with or without NNPA enabled.

@tungld (Collaborator) left a comment

LGTM.

    ONNXSplitToSequenceOp splitToSequence;
    if (!(splitToSequence = mlir::dyn_cast<ONNXSplitToSequenceOp>(
              inputSequence.getDefiningOp())))
      return false;
Collaborator: Can be shortened with

    ONNXSplitToSequenceOp splitToSequence =
        inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
    if (!splitToSequence)
      return false;
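
(Worth noting: besides being shorter, Value::getDefiningOp<OpTy>() also tolerates a null defining op, e.g. when the value is a block argument, returning a null OpTy rather than asserting inside the cast.)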

Collaborator Author: Done

    Value replaceSequenceAt(
        PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
      ONNXSequenceAtOp op =
          mlir::cast<ONNXSequenceAtOp>(sequenceAtResult.getDefiningOp());
Collaborator: Same here, we can use ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>()

Collaborator Author: Done

    @@ -444,7 +444,7 @@ TensorType OnnxBuilder::toTensor(Type input) const {
     }

     TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
    -  assert(inputs.size() >= 2 && "Expect at least two inputs");
    +  //assert(inputs.size() >= 2 && "Expect at least two inputs");
Collaborator: Looks like we can remove this completely.

Collaborator Author: Done

Signed-off-by: chentong319 <[email protected]>
@AlexandreEichenberger (Collaborator) left a comment

LGTM, thanks for the quick fix and the lit-test.

It would be interesting to study how this pattern fares compared to the original pattern (it was for an LSTM, correct?).

@AlexandreEichenberger (Collaborator):

@chentong319 Dominic is eager to test this. Can you implement the minor suggestions and merge it into main? Thanks

@AlexandreEichenberger (Collaborator):

Thanks @chentong319.

@chentong319 merged commit 45f07d5 into onnx:main on Dec 4, 2024
6 of 7 checks passed
@jenkins-droid:

Jenkins Linux amd64 Build #16051 [push] Transform SequenceAt to ... started at 12:34

@jenkins-droid:

Jenkins Linux s390x Build #16054 [push] Transform SequenceAt to ... started at 13:34

@jenkins-droid:

Jenkins Linux ppc64le Build #15081 [push] Transform SequenceAt to ... started at 13:50

@jenkins-droid:

Jenkins Linux amd64 Build #16051 [push] Transform SequenceAt to ... failed after 1 min 36 sec

@jenkins-droid:

Jenkins Linux s390x Build #16054 [push] Transform SequenceAt to ... passed after 1 hr 41 min

@jenkins-droid:

Jenkins Linux ppc64le Build #15081 [push] Transform SequenceAt to ... passed after 2 hr 18 min
