Commit 6f85162 (parent: 72e5150)

Extend quantiser support to accelerate more binary models.

Add the ability to convert `tf.where`-style binary quantisers, and add support for boolean input to `LceQuantize` and boolean output from `LceDequantize`.
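
For reference, the `tf.where`-style quantiser this commit learns to convert is a forward-pass equivalent of Larq's `ste_sign`. A minimal sketch, mirroring the quantiser added to the end-to-end tests below:

    import tensorflow as tf

    def tf_where_binary_quantizer(x):
        # Maps x >= 0 to +1.0 and x < 0 to -1.0, matching the
        # forward pass of `ste_sign`.
        return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))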

File tree

10 files changed (+178, −23 lines)


larq_compute_engine/mlir/ir/lce_ops.td

Lines changed: 4 additions & 4 deletions

@@ -70,11 +70,11 @@ def LQ_QuantizeOp : LQ_Op<"Quantize", [NoSideEffect]> {
   let summary = "Binary quantize operator";

   let description = [{
-Converts floating point or integer tensors to binarized bitpacked tensors.
+Converts floating point, integer, or boolean tensors to binarized bitpacked tensors.
   }];

   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$x
+    TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$x
   );

   let results = (outs
@@ -90,15 +90,15 @@ def LQ_DequantizeOp : LQ_Op<"Dequantize", [NoSideEffect]> {
   let summary = "Binary dequantize operator";

   let description = [{
-Converts binarized bitpacked tensors to floating point or integer tensors.
+Converts binarized bitpacked tensors to floating point, integer, or boolean tensors.
   }];

   let arguments = (ins
     TensorOf<[I32]>:$x
   );

   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$y
+    TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$y
   );

   let hasFolder = 1;
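
For context on the tensor types above: `lq.Quantize` packs the innermost dimension 32 binary values per `i32` word, which is why the tests below map e.g. tensor<48x48x64xf32> to tensor<48x48x2xi32>. A small Python sketch of this shape rule (`bitpacked_shape` is a hypothetical helper for illustration, not part of this commit):

    import math

    def bitpacked_shape(shape, bitwidth=32):
        # All dimensions except the innermost are unchanged; the
        # innermost is packed `bitwidth` binary values per word.
        *outer, channels = shape
        return (*outer, math.ceil(channels / bitwidth))

    assert bitpacked_shape((48, 48, 64)) == (48, 48, 2)
    assert bitpacked_shape((8, 16)) == (8, 1)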

larq_compute_engine/mlir/tests/optimize.mlir

Lines changed: 44 additions & 0 deletions

@@ -1,6 +1,50 @@
 // RUN: lce-tf-opt %s -tfl-optimize-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
 // RUN: lce-tf-opt %s -tfl-optimize-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE

+// CHECK-LABEL: @optimize_quantize_greater_equal_zero
+func @optimize_quantize_greater_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
+  %cst = constant dense<0.0> : tensor<f32>
+  %0 = "tfl.greater_equal"(%arg0, %cst) : (tensor<48x48x64xf32>, tensor<f32>) -> tensor<48x48x64xi1>
+  %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
+  return %1 : tensor<48x48x2xi32>
+
+  // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
+  // CHECK-NEXT: return %0
+}
+
+// CHECK-LABEL: @optimize_quantize_greater_equal_non_zero
+func @optimize_quantize_greater_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
+  %0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
+  %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
+  return %1 : tensor<48x48x2xi32>
+
+  // CHECK-NEXT: %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
+  // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
+  // CHECK-NEXT: return %1
+}
+
+// CHECK-LABEL: @optimize_quantize_less_equal_zero
+func @optimize_quantize_less_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
+  %cst = constant dense<0.0> : tensor<64xf32>
+  %0 = "tfl.less_equal"(%cst, %arg0) : (tensor<64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
+  %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
+  return %1 : tensor<48x48x2xi32>
+
+  // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
+  // CHECK-NEXT: return %0
+}
+
+// CHECK-LABEL: @optimize_quantize_less_equal_non_zero
+func @optimize_quantize_less_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
+  %0 = "tfl.less_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
+  %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
+  return %1 : tensor<48x48x2xi32>
+
+  // CHECK-NEXT: %0 = tfl.sub %arg1, %arg0 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
+  // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
+  // CHECK-NEXT: return %1
+}
+
 // CHECK-LABEL: @fuse_add_into_bconv2d
 func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<1.5> : tensor<16xf32>

larq_compute_engine/mlir/tests/prepare-tf.mlir

Lines changed: 26 additions & 2 deletions

@@ -1,8 +1,32 @@
 // RUN: lce-tf-opt %s -tfl-prepare-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
 // RUN: lce-tf-opt %s -tfl-prepare-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE

-// CHECK-LABEL: @fuse_bsign
-func @fuse_bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+// CHECK-LABEL: @fuse_bsign_tf_where
+func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
+  %cst_l = constant dense<1.0> : tensor<8x16xf32>
+  %cst_r = constant dense<-1.0> : tensor<8x16xf32>
+  %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+
+  // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
+  // CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return %1
+}
+
+// CHECK-LABEL: @fuse_bsign_tf_where_scalar_constants
+func @fuse_bsign_tf_where_scalar_constants(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
+  %cst_l = constant dense<1.0> : tensor<f32>
+  %cst_r = constant dense<-1.0> : tensor<f32>
+  %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<f32>, tensor<f32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+
+  // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
+  // CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return %1
+}
+
+// CHECK-LABEL: @fuse_bsign_legacy_tf_sign
+func @fuse_bsign_legacy_tf_sign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
   %cst = constant dense<0.1> : tensor<f32>
   %2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>

larq_compute_engine/mlir/transforms/optimize_patterns_common.td

Lines changed: 24 additions & 0 deletions

@@ -11,6 +11,30 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

 class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

+def : Pat<(LQ_QuantizeOp
+              (TFL_GreaterEqualOp:$ge_op
+                  $input,
+                  (ConstantOp:$threshold ConstantValue<"0.0f">))),
+          (LQ_QuantizeOp $input),
+          [(HasOneUse $ge_op), (HasOneUse $threshold)],
+          (addBenefit 150)>;
+
+def : Pat<(LQ_QuantizeOp
+              (TFL_GreaterEqualOp:$ge_op
+                  $input,
+                  $threshold)),
+          (LQ_QuantizeOp
+              (TFL_SubOp $input, $threshold, TFL_AF_None)),
+          [(HasOneUse $ge_op)],
+          (addBenefit 100)>;
+
+def : Pat<(LQ_QuantizeOp
+              (TFL_LessEqualOp:$le_op $lhs, $rhs)),
+          (LQ_QuantizeOp
+              (TFL_GreaterEqualOp $rhs, $lhs)),
+          [(HasOneUse $le_op)],
+          (addBenefit 100)>;
+
 // TODO: Check shapes before fusing
 multiclass FuseAddOrSubWithBConv2D<Op binaryOp> {
   def : Pat<(binaryOp
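
These patterns rely on two elementwise identities: `x >= t` holds exactly when `x - t >= 0` (and a sign test against zero is what `lq.Quantize` computes), while `lhs <= rhs` is just `rhs >= lhs` with the operands swapped. A quick NumPy check of both, illustrative only:

    import numpy as np

    x = np.random.randn(48, 48, 64).astype(np.float32)
    t = np.random.randn(48, 48, 64).astype(np.float32)

    # Quantizing the boolean `x >= t` gives the same bits as
    # quantizing the real-valued `x - t` against a zero threshold:
    assert np.array_equal(x >= t, (x - t) >= 0)

    # ...and less-equal is greater-equal with operands swapped:
    assert np.array_equal(x <= t, t >= x)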

larq_compute_engine/mlir/transforms/prepare_patterns_common.td

Lines changed: 11 additions & 2 deletions

@@ -4,9 +4,18 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "larq_compute_engine/mlir/ir/lce_ops.td"
 include "larq_compute_engine/mlir/transforms/op_removal_patterns.td"

+class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

-// This relies on implementation details of larq.math.sign. We should make
-// this more general in the future
+// Base quantiser pattern that matches the `tf.where` implementation of `ste_sign`.
+def : Pat<(TF_SelectV2Op:$select_op
+              $cond,
+              (ConstantOp ConstantValue<"1.0f">),
+              (ConstantOp ConstantValue<"-1.0f">)),
+          (LQ_DequantizeOp (LQ_QuantizeOp $cond)),
+          [], (addBenefit 100)>;
+
+// A fallback for the old version of `ste_sign` that uses a specific `tf.sign`
+// based implementation of `larq.math.sign`.
 def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)),
           (LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
 def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))),
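
The intuition behind the base pattern: `tf.SelectV2(cond, +1.0, -1.0)` produces exactly the +1/−1 float tensor that binarizing `cond` and dequantizing it back would produce. A short TensorFlow sketch of the equivalence the pattern assumes, using a plain cast as a stand-in for the `lq.Quantize`/`lq.Dequantize` round trip:

    import tensorflow as tf

    cond = tf.constant([[True, False], [False, True]])
    ones = tf.ones(cond.shape)

    selected = tf.where(cond, ones, -ones)
    # A bool -> {+1.0, -1.0} round trip, as lq.Quantize followed by
    # lq.Dequantize would yield:
    roundtrip = tf.cast(cond, tf.float32) * 2.0 - 1.0

    tf.debugging.assert_equal(selected, roundtrip)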

larq_compute_engine/mlir/transforms/prepare_tf.cc

Lines changed: 8 additions & 0 deletions

@@ -36,6 +36,14 @@ struct PrepareLCE : public PassWrapper<PrepareLCE, FunctionPass> {
                  clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
 };

+bool IsConstantValue(Attribute values, float expected_value) {
+  if (!values.isa<DenseElementsAttr>()) return false;
+
+  for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
+    if (value != expected_value) return false;
+  }
+  return true;
+}
 DenseElementsAttr GetConstantVector(Attribute filter, float val) {
   auto filter_type = filter.getType().cast<ShapedType>();
   auto filter_shape = filter_type.getShape();

larq_compute_engine/tests/end2end_test.py

Lines changed: 25 additions & 11 deletions

@@ -1,3 +1,4 @@
+import functools
 import os
 import sys
 import tempfile
@@ -23,7 +24,7 @@ def convert_keras_model_as_saved_model(model, **kwargs):
     return convert_saved_model(saved_model_dir, **kwargs)


-def toy_model(**kwargs):
+def toy_model(binary_quantizer="ste_sign", **kwargs):
     def block(padding, pad_values, activation):
         def dummy(x):
             shortcut = x
@@ -32,8 +33,8 @@ def dummy(x):
                 kernel_size=3,
                 padding=padding,
                 pad_values=pad_values,
-                input_quantizer="ste_sign",
-                kernel_quantizer="ste_sign",
+                input_quantizer=binary_quantizer,
+                kernel_quantizer=binary_quantizer,
                 use_bias=False,
                 activation=activation,
             )(x)
@@ -59,7 +60,7 @@ def dummy(x):
     return tf.keras.Model(inputs=img_input, outputs=out)


-def toy_model_sequential(**kwargs):
+def toy_model_sequential(binary_quantizer="ste_sign", **kwargs):
     return tf.keras.models.Sequential(
         [
             tf.keras.layers.Input((224, 224, 3)),
@@ -70,8 +71,8 @@ def toy_model_sequential(**kwargs):
             lq.layers.QuantConv2D(
                 32,
                 (3, 3),
-                input_quantizer="ste_sign",
-                kernel_quantizer="ste_sign",
+                input_quantizer=binary_quantizer,
+                kernel_quantizer=binary_quantizer,
                 padding="same",
                 pad_values=1.0,
                 use_bias=False,
@@ -85,8 +86,8 @@ def toy_model_sequential(**kwargs):
             lq.layers.QuantConv2D(
                 32,
                 (3, 3),
-                input_quantizer="ste_sign",
-                kernel_quantizer="ste_sign",
+                input_quantizer=binary_quantizer,
+                kernel_quantizer=binary_quantizer,
                 strides=(2, 2),
                 padding="same",
                 pad_values=1.0,
@@ -104,8 +105,8 @@ def toy_model_sequential(**kwargs):
             lq.layers.QuantConv2D(
                 32,
                 (3, 3),
-                input_quantizer="ste_sign",
-                kernel_quantizer="ste_sign",
+                input_quantizer=binary_quantizer,
+                kernel_quantizer=binary_quantizer,
                 padding="same",
                 pad_values=1.0,
                 use_bias=False,
@@ -165,12 +166,25 @@ def dataset():
     )


+def tf_where_binary_quantizer(x):
+    return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))
+
+
 @pytest.mark.parametrize(
     "conversion_function", [convert_keras_model, convert_keras_model_as_saved_model]
 )
 @pytest.mark.parametrize(
     "model_cls",
-    [toy_model, toy_model_sequential, toy_model_int8, lqz.sota.QuickNetSmall],
+    [
+        toy_model,
+        functools.partial(toy_model, binary_quantizer=tf_where_binary_quantizer),
+        toy_model_sequential,
+        functools.partial(
+            toy_model_sequential, binary_quantizer=tf_where_binary_quantizer
+        ),
+        toy_model_int8,
+        lqz.sota.QuickNetSmall,
+    ],
 )
 def test_simple_model(dataset, conversion_function, model_cls):
     model = model_cls(weights="imagenet")

larq_compute_engine/tflite/kernels/BUILD

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ cc_library(
         "//larq_compute_engine/core/indirect_bgemm:kernels",
         "@flatbuffers",
         "@org_tensorflow//tensorflow/lite:framework",
+        "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
        "@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils",
        "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
        "@ruy//ruy/profiler:instrumentation",

larq_compute_engine/tflite/kernels/quantization.cc

Lines changed: 33 additions & 4 deletions

@@ -1,9 +1,12 @@
+#include <type_traits>
+
 #include "larq_compute_engine/core/bitpacking/utils.h"
 #include "ruy/profiler/instrumentation.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/cppmath.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/portable_type_to_tflitetype.h"

 using namespace tflite;

@@ -20,8 +23,9 @@ TfLiteStatus QuantizePrepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);

-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
+                              input->type == kTfLiteInt8 ||
+                              input->type == kTfLiteBool);
   TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt32);

   int num_dims = NumDimensions(input);
@@ -44,8 +48,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);

   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32);
-  TF_LITE_ENSURE(context,
-                 output->type == kTfLiteFloat32 || output->type == kTfLiteInt8);
+  TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
+                              output->type == kTfLiteInt8 ||
+                              output->type == kTfLiteBool);

   int num_dims = NumDimensions(input);

@@ -80,6 +85,27 @@ TfLiteStatus QuantizeEval(TfLiteContext* context, TfLiteNode* node) {
   } else if (input->type == kTfLiteInt8) {
     bitpack_tensor(GetTensorShape(input), GetTensorData<std::int8_t>(input),
                    input->params.zero_point, GetTensorData<TBitpacked>(output));
+  } else if (input->type == kTfLiteBool) {
+    // The strategy here is to interpret the input data as an unsigned integer
+    // (of the same width as the bool type for the target). We then call
+    // bitpacking, with a 'zero point' of 1. This means that the value with all
+    // zero bits will be bitpacked as bit 1, and all other values will be
+    // bitpacked as bit 0. Assuming that `false` is represented by a value with
+    // all zero bits, this gives the correct result of bitpacking `false` as bit
+    // 1 and `true` as bit 0.
+
+    static_assert(std::is_same<::tflite::TfLiteTypeToType<kTfLiteBool>::Type,
+                               bool>::value,
+                  "");
+    using BOOL_UINT = std::conditional<
+        sizeof(bool) == 1, std::uint8_t,
+        std::conditional<sizeof(bool) == 2, std::uint16_t,
+                         std::conditional<sizeof(bool) == 4, std::uint32_t,
+                                          std::uint64_t>::type>::type>::type;
+    static_assert(sizeof(bool) == sizeof(BOOL_UINT), "");
+
+    bitpack_tensor(GetTensorShape(input), GetTensorData<BOOL_UINT>(input),
+                   BOOL_UINT(1), GetTensorData<TBitpacked>(output));
   } else {
     return kTfLiteError;
   }
@@ -110,6 +136,9 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
     unpack_matrix(GetTensorData<TBitpacked>(input), num_rows, num_cols,
                   GetTensorData<std::int8_t>(output), zero_bit_result,
                   one_bit_result);
+  } else if (output->type == kTfLiteBool) {
+    unpack_matrix(GetTensorData<TBitpacked>(input), num_rows, num_cols,
+                  GetTensorData<bool>(output), true, false);
   } else {
     return kTfLiteError;
   }
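
The 'zero point of 1' trick described in the comment above can be sanity-checked in a few lines of Python. This is a sketch of the convention the comment describes (a value comparing below the zero point packs as bit 1, everything else as bit 0), not the actual kernel:

    def pack_bits(values, zero_point):
        # Bit i is set iff values[i] < zero_point, so with zero_point=1
        # only the all-zero-bits value 0 produces bit 1.
        word = 0
        for i, v in enumerate(values):
            if v < zero_point:
                word |= 1 << i
        return word

    # bools reinterpreted as unsigned integers: false == 0, true != 0,
    # so false -> bit 1 and true -> bit 0.
    assert pack_bits([0, 1, 1, 0], zero_point=1) == 0b1001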

larq_compute_engine/tflite/tests/quantization_test.cc

Lines changed: 2 additions & 0 deletions

@@ -116,6 +116,8 @@ TEST_P(QuantizationOpTest, Float) { TestQuantization<float>(GetParam()); }

 TEST_P(QuantizationOpTest, Int8) { TestQuantization<std::int8_t>(GetParam()); }

+TEST_P(QuantizationOpTest, Bool) { TestQuantization<bool>(GetParam()); }
+
 INSTANTIATE_TEST_SUITE_P(AllCombinations, QuantizationOpTest,
                          ::testing::Values(std::array<int, 4>{1, 1, 1, 1},
                                            std::array<int, 4>{1, 4, 4, 1},
