Skip to content

Commit a76f70c

Browse files
committed
Do not fuse locations when normalizing constants for Add and Mul
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent f2ec420 commit a76f70c

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/Dialect/ONNX/Transforms/ConstProp.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ def CreateScatterNDOfConst :
302302
// Use commutativity to normalize constants in the second position of Add.
303303
def AddConstCommutative1 : NamedPat<"AddConstCommutative1",
304304
// From add(c, x).
305-
(ONNXAddOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
305+
(ONNXAddOp:$addOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
306306
// To add(x, c).
307-
(ONNXAddOp $x, $c),
307+
(ONNXAddOp $x, $c, (location $addOp)),
308308
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
309309
[(IsNotAConstant:$x)]>;
310310

@@ -575,9 +575,9 @@ def SumConstProp : NamedPat<"SumConstProp",
575575
// Use commutativity to normalize constants in the second position of Mul.
576576
def MulConstCommutative1 : NamedPat<"MulConstCommutative1",
577577
// From mul(c, x).
578-
(ONNXMulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
578+
(ONNXMulOp:$mulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
579579
// To mul(x, c).
580-
(ONNXMulOp $x, $c),
580+
(ONNXMulOp $x, $c, (location $mulOp)),
581581
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
582582
[(IsNotAConstant:$x)]>;
583583

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file --mlir-print-debuginfo | FileCheck %s
2+
3+
4+
//===----------------------------------------------------------------------===//
5+
/// Commutative tests
6+
7+
// CHECK-LABEL: @test_add_constant_1_loc
8+
func.func @test_add_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
9+
%0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant")
10+
%1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Add")
11+
"onnx.Return"(%1) : (tensor<3xf32>) -> ()
12+
// CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]])
13+
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_ADD:#.+]])
14+
// CHECK-DAG: [[LOC_CONST]] = loc("Constant")
15+
// CHECK-DAG: [[LOC_ADD]] = loc("Add")
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: @test_mul_constant_1_loc
21+
func.func @test_mul_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
22+
%0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant")
23+
%1 = "onnx.Mul"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Mul")
24+
"onnx.Return"(%1) : (tensor<3xf32>) -> ()
25+
// CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]])
26+
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_MUL:#.+]])
27+
// CHECK-DAG: [[LOC_CONST]] = loc("Constant")
28+
// CHECK-DAG: [[LOC_MUL]] = loc("Mul")
29+
}
30+

0 commit comments

Comments
 (0)