Skip to content

Commit a7653c1

Browse files
authored
Remove the pattern unstick_4ds_squeeze_stick_3ds (#3062)
Signed-off-by: Tung D. Le <[email protected]>
1 parent d35d593 commit a7653c1

File tree

8 files changed

+110
-0
lines changed

8 files changed

+110
-0
lines changed

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,5 +636,25 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
636636
return IntegerAttr();
637637
}
638638

639+
// Create an array tensor to contain three dimensions of layout 3DS.
640+
// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
641+
// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
642+
// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
643+
Value create3DSShapeFrom4DS(OpBuilder &builder, Location loc, Value val4DS) {
644+
OnnxBuilder create(builder, loc);
645+
ArrayRef<int64_t> shape4DS = getShape(val4DS.getType());
646+
assert(shape4DS.size() == 4 && "The tensor must have rank of 4");
647+
assert(shape4DS[1] == 1 && "The second dim must be 1");
648+
if (hasStaticShape(val4DS.getType())) {
649+
return create.constantInt64(
650+
ArrayRef<int64_t>{shape4DS[0], shape4DS[2], shape4DS[3]});
651+
}
652+
Value dim0 = create.dim(val4DS, 0);
653+
Value dim1 = create.dim(val4DS, 2);
654+
Value dim2 = create.dim(val4DS, 3);
655+
return create.concat(
656+
RankedTensorType::get({3}, builder.getI64Type()), {dim0, dim1, dim2}, 0);
657+
}
658+
639659
} // namespace zhigh
640660
} // namespace onnx_mlir

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ bool hasNNPAUse(mlir::Value v);
110110
/// Get saturation settings.
111111
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);
112112

113+
/// Create an array tensor to contain three dimensions of layout 3DS.
114+
/// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
115+
/// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
116+
/// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
117+
mlir::Value create3DSShapeFrom4DS(
118+
mlir::OpBuilder &builder, mlir::Location loc, mlir::Value val3DS);
119+
113120
} // namespace zhigh
114121
} // namespace onnx_mlir
115122
#endif

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ def TensorHas3DSLayout: Constraint<
109109
"ZTensor has 3DS layout"
110110
>;
111111

112+
// Constraint satisfied when the layout obtained from the type of $0 through
// getZTensorLayout, once converted to a string attribute, is accepted by
// is4DSLayout.
def TensorHas4DSLayout: Constraint<
113+
CPred<"::onnx_mlir::is4DSLayout("
114+
"::onnx_mlir::zhigh::convertZTensorDataLayoutToStringAttr($_builder, "
115+
"::onnx_mlir::zhigh::getZTensorLayout($0.getType())))">,
116+
"ZTensor has 4DS layout"
117+
>;
118+
112119
def TensorHasNHWCLayout: Constraint<
113120
CPred<"::onnx_mlir::isNHWCLayout("
114121
"::onnx_mlir::zhigh::convertZTensorDataLayoutToStringAttr($_builder, "
@@ -238,4 +245,17 @@ def GetDefaultSaturation : NativeCodeCall<
238245
"::onnx_mlir::zhigh::getDefaultSaturation($_builder)"
239246
>;
240247

248+
// Constraint that holds when `::onnx_mlir::isConstOf` accepts the bound value
// $0 together with the integer `v` given as the class argument.
class IsConstOf<int v>: Constraint<
  CPred<"::onnx_mlir::isConstOf($0, " # v # ")">,
  "Value is a scalar constant of v"
>;
252+
253+
// Create a tensor<3xi64> holding the three dimensions of a 3DS shape.
254+
// The shape is derived from the 4DS shape of $0 by dropping the unit dim at axis 1.
255+
// e.g. for a 4DS tensor of shape 3x1x4x5,
256+
// this native call returns a tensor<3xi64> holding [3, 4, 5].
257+
def Create3DSShapeFrom4DS: NativeCodeCall<
258+
"::onnx_mlir::zhigh::create3DSShapeFrom4DS($_builder, $_loc, $0)"
259+
>;
260+
241261
#endif // OP_HELPER

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
1515
#include "src/Compiler/CompilerOptions.hpp"
16+
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
1617

1718
using namespace mlir;
1819
using namespace onnx_mlir;
@@ -138,6 +139,7 @@ void ZHighStickOp::getCanonicalizationPatterns(
138139
results.insert<NoneTypeStickRemovalPattern>(context);
139140
results.insert<StickUnstickSameLayoutRemovalPattern>(context);
140141
results.insert<StickUnstickDiffLayoutRemovalPattern>(context);
142+
results.insert<Stick3DSSqueezeUnstick4DSPattern>(context);
141143
results.insert<ReplaceONNXLeakyReluPattern>(context);
142144
results.insert<ReplaceONNXSoftplusPattern>(context);
143145
results.insert<ReplaceONNXReciprocalSqrtPattern>(context);

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,4 +317,23 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat<
317317
]
318318
>;
319319

320+
// Pattern in the CCFD model.
321+
// 4DS and 3DS have exactly the same data values when the second dim of 4DS is 1.
322+
// (The second dim of 4DS indicates unidirectional or bidirectional LSTM/GRU/RNN,
323+
// in which: 1 means unidirectional, 2 means bidirectional)
324+
// This rewriting is valid whether the dims (except the 2nd dim) are dynamic or static.
// NOTE(review): none of the conditions below checks the second dim of $X, while
// the helper behind Create3DSShapeFrom4DS asserts that it is 1 -- confirm that a
// 4DS input with a dynamic second dim cannot reach this pattern.
325+
def Stick3DSSqueezeUnstick4DSPattern: Pat<
326+
// Input: X -> unstick (4DS) -> Squeeze (axis=1) -> stick (3DS).
327+
(ZHighStickOp:$stick
328+
(ONNXSqueezeOp (ZHighUnstickOp:$unstick $X), $axes),
329+
$_, $_),
330+
// Output: the initial X value, reshaped to the 3DS shape and given the layout of $stick.
331+
(ZHighReshapeOp $X, (Create3DSShapeFrom4DS $X), (GetLayout $stick)),
332+
// Conditions.
333+
[(TensorHas4DSLayout $X), // Input is 4DS.
334+
(TensorHas3DSLayout $stick), // Output is 3DS.
335+
(IsConstOf<1> $axes), // squeeze at axis 1.
336+
]
337+
>;
338+
320339
#endif // STICK_TD

src/Accelerators/NNPA/Support/LayoutHelper.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ bool is4DLayout(StringAttr layout) {
8181
return (layout && layout.getValue().equals_insensitive(LAYOUT_4D));
8282
}
8383

84+
// Return true when `layout` is a non-null attribute whose string value equals
// LAYOUT_4DS, ignoring case; return false otherwise.
bool is4DSLayout(StringAttr layout) {
  if (!layout)
    return false;
  return layout.getValue().equals_insensitive(LAYOUT_4DS);
}
87+
8488
// Return true when `layout` is a non-null attribute whose string value equals
// LAYOUT_NHWC, ignoring case; return false otherwise.
bool isNHWCLayout(StringAttr layout) {
  if (!layout)
    return false;
  return layout.getValue().equals_insensitive(LAYOUT_NHWC);
}

src/Accelerators/NNPA/Support/LayoutHelper.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ zdnn_data_layouts convertLayoutAttrToZDNNDataLayout(
4747
bool is2DLayout(mlir::StringAttr layout);
4848
bool is3DSLayout(mlir::StringAttr layout);
4949
bool is4DLayout(mlir::StringAttr layout);
50+
bool is4DSLayout(mlir::StringAttr layout);
5051
bool isNHWCLayout(mlir::StringAttr layout);
5152

5253
mlir::StringAttr getNCHWLayoutAttr(mlir::PatternRewriter &rewriter);

test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,40 @@ func.func @test_Roberta_pattern2_bs1_notmod64(%arg0: tensor<1x384x756xf32>, %arg
528528
// CHECK: }
529529
}
530530

531+
// -----
532+
533+
// Static-shape case: unstick (4DS) -> squeeze on axis 1 -> stick (3DS) is
// expected to become a single zhigh.Reshape of the 4DS argument, whose shape
// operand is the constant [7, 128, 200].
func.func @replace_unstick_squeeze_stick_static(%arg0: tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
534+
%cst1 = onnx.Constant dense<1> : tensor<1xi64>
535+
%0 = "zhigh.Unstick"(%arg0) : (tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<7x1x128x200xf32>
536+
%1 = "onnx.Squeeze"(%0, %cst1) : (tensor<7x1x128x200xf32>, tensor<1xi64>) -> tensor<7x128x200xf32>
537+
%2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<7x128x200xf32>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
538+
"func.return"(%2) : (tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> ()
539+
540+
// CHECK-LABEL: func.func @replace_unstick_squeeze_stick_static
541+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
542+
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[7, 128, 200]> : tensor<3xi64>
543+
// CHECK: [[VAR_1_:%.+]] = "zhigh.Reshape"([[PARAM_0_]], [[VAR_0_]]) {layout = "3DS"} : (tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>, tensor<3xi64>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
544+
// CHECK: return [[VAR_1_]] : tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
545+
// CHECK: }
546+
}
547+
548+
// -----
549+
550+
// Dynamic-shape case (dims 0 and 2 unknown): unstick (4DS) -> squeeze on
// axis 1 -> stick (3DS) is expected to become a single zhigh.Reshape whose
// shape operand is a concat of onnx.Dim results for axes 0 and 2 and the
// constant 200.
func.func @replace_unstick_squeeze_stick_dynamic(%arg0: tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
551+
%cst1 = onnx.Constant dense<1> : tensor<1xi64>
552+
%0 = "zhigh.Unstick"(%arg0) : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<?x1x?x200xf32>
553+
%1 = "onnx.Squeeze"(%0, %cst1) : (tensor<?x1x?x200xf32>, tensor<1xi64>) -> tensor<?x?x200xf32>
554+
%2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<?x?x200xf32>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
555+
"func.return"(%2) : (tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> ()
556+
557+
// CHECK-LABEL: func.func @replace_unstick_squeeze_stick_dynamic
558+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
559+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<200> : tensor<1xi64>
560+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<1xi64>
561+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 2 : si64} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<1xi64>
562+
// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
563+
// CHECK: [[VAR_4_:%.+]] = "zhigh.Reshape"([[PARAM_0_]], [[VAR_3_]]) {layout = "3DS"} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>, tensor<3xi64>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
564+
// CHECK: return [[VAR_4_]] : tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
565+
// CHECK: }
566+
}
567+

0 commit comments

Comments
 (0)