@@ -528,3 +528,40 @@ func.func @test_Roberta_pattern2_bs1_notmod64(%arg0: tensor<1x384x756xf32>, %arg
528
528
// CHECK: }
529
529
}
530
530
531
+ // -----
532
+
533
+ func.func @replace_unstick_squeeze_stick_static (%arg0: tensor <7 x1 x128 x200 xf16 , #zhigh.layout <{dataLayout = " 4DS" }>>) -> tensor <7 x128 x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>> {
534
+ %cst1 = onnx.Constant dense <1 > : tensor <1 xi64 >
535
+ %0 = " zhigh.Unstick" (%arg0 ) : (tensor <7 x1 x128 x200 xf16 , #zhigh.layout <{dataLayout = " 4DS" }>>) -> tensor <7 x1 x128 x200 xf32 >
536
+ %1 = " onnx.Squeeze" (%0 , %cst1 ) : (tensor <7 x1 x128 x200 xf32 >, tensor <1 xi64 >) -> tensor <7 x128 x200 xf32 >
537
+ %2 = " zhigh.Stick" (%1 ) {layout = " 3DS" } : (tensor <7 x128 x200 xf32 >) -> tensor <7 x128 x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>>
538
+ " func.return" (%2 ) : (tensor <7 x128 x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>>) -> ()
539
+
540
+ // CHECK-LABEL: func.func @replace_unstick_squeeze_stick_static
541
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
542
+ // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[7, 128, 200]> : tensor<3xi64>
543
+ // CHECK: [[VAR_1_:%.+]] = "zhigh.Reshape"([[PARAM_0_]], [[VAR_0_]]) {layout = "3DS"} : (tensor<7x1x128x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>, tensor<3xi64>) -> tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
544
+ // CHECK: return [[VAR_1_]] : tensor<7x128x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
545
+ // CHECK: }
546
+ }
547
+
548
+ // -----
549
+
550
+ func.func @replace_unstick_squeeze_stick_dynamic (%arg0: tensor <?x1 x?x200 xf16 , #zhigh.layout <{dataLayout = " 4DS" }>>) -> tensor <?x?x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>> {
551
+ %cst1 = onnx.Constant dense <1 > : tensor <1 xi64 >
552
+ %0 = " zhigh.Unstick" (%arg0 ) : (tensor <?x1 x?x200 xf16 , #zhigh.layout <{dataLayout = " 4DS" }>>) -> tensor <?x1 x?x200 xf32 >
553
+ %1 = " onnx.Squeeze" (%0 , %cst1 ) : (tensor <?x1 x?x200 xf32 >, tensor <1 xi64 >) -> tensor <?x?x200 xf32 >
554
+ %2 = " zhigh.Stick" (%1 ) {layout = " 3DS" } : (tensor <?x?x200 xf32 >) -> tensor <?x?x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>>
555
+ " func.return" (%2 ) : (tensor <?x?x200 xf16 , #zhigh.layout <{dataLayout = " 3DS" }>>) -> ()
556
+
557
+ // CHECK-LABEL: func.func @replace_unstick_squeeze_stick_dynamic
558
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
559
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<200> : tensor<1xi64>
560
+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<1xi64>
561
+ // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 2 : si64} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>) -> tensor<1xi64>
562
+ // CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
563
+ // CHECK: [[VAR_4_:%.+]] = "zhigh.Reshape"([[PARAM_0_]], [[VAR_3_]]) {layout = "3DS"} : (tensor<?x1x?x200xf16, #zhigh.layout<{dataLayout = "4DS"}>>, tensor<3xi64>) -> tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
564
+ // CHECK: return [[VAR_4_]] : tensor<?x?x200xf16, #zhigh.layout<{dataLayout = "3DS"}>>
565
+ // CHECK: }
566
+ }
567
+
0 commit comments