
Commit fb76a73

Add preliminary support for the xcore target.
1 parent 95fd0b0 commit fb76a73

15 files changed, +242 -96 lines changed

larq_compute_engine/mlir/BUILD (42 additions, 8 deletions)

@@ -36,28 +36,60 @@ gentbl(
 )

 gentbl(
-    name = "prepare_lce_inc_gen",
+    name = "prepare_lce_target_arm_inc_gen",
     tbl_outs = [
-        ("-gen-rewriters", "transforms/generated_prepare.inc"),
+        ("-gen-rewriters", "transforms/generated_prepare_target_arm.inc"),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "transforms/prepare_patterns.td",
+    td_file = "transforms/prepare_patterns_target_arm.td",
     td_srcs = [
         "ir/lce_ops.td",
         "transforms/op_removal_patterns.td",
+        "transforms/prepare_patterns_common.td",
         "@llvm-project//mlir:StdOpsTdFiles",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
         "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
     ],
 )

 gentbl(
-    name = "optimize_lce_inc_gen",
+    name = "prepare_lce_target_other_inc_gen",
     tbl_outs = [
-        ("-gen-rewriters", "transforms/generated_optimize.inc"),
+        ("-gen-rewriters", "transforms/generated_prepare_target_other.inc"),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "transforms/optimize_patterns.td",
+    td_file = "transforms/prepare_patterns_common.td",
+    td_srcs = [
+        "ir/lce_ops.td",
+        "transforms/op_removal_patterns.td",
+        "@llvm-project//mlir:StdOpsTdFiles",
+        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
+    ],
+)
+
+gentbl(
+    name = "optimize_lce_target_arm_inc_gen",
+    tbl_outs = [
+        ("-gen-rewriters", "transforms/generated_optimize_target_arm.inc"),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "transforms/optimize_patterns_target_arm.td",
+    td_srcs = [
+        "ir/lce_ops.td",
+        "transforms/optimize_patterns_common.td",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
+        "@llvm-project//mlir:StdOpsTdFiles",
+    ],
+)
+
+gentbl(
+    name = "optimize_lce_target_other_inc_gen",
+    tbl_outs = [
+        ("-gen-rewriters", "transforms/generated_optimize_target_other.inc"),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "transforms/optimize_patterns_common.td",
     td_srcs = [
         "ir/lce_ops.td",
         "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",

@@ -136,7 +168,8 @@ cc_library(
 cc_library(
     name = "larq_compute_engine_prepare",
     srcs = [
-        "transforms/generated_prepare.inc",
+        "transforms/generated_prepare_target_arm.inc",
+        "transforms/generated_prepare_target_other.inc",
         "transforms/prepare_tf.cc",
     ],
     hdrs = [

@@ -157,7 +190,8 @@ cc_library(
 cc_library(
     name = "larq_compute_engine_optimize",
     srcs = [
-        "transforms/generated_optimize.inc",
+        "transforms/generated_optimize_target_arm.inc",
+        "transforms/generated_optimize_target_other.inc",
         "transforms/optimize.cc",
     ],
     hdrs = [

larq_compute_engine/mlir/python/converter.py (5 additions, 0 deletions)

@@ -55,6 +55,7 @@ def convert_keras_model(
     *,  # Require remaining arguments to be keyword-only.
     inference_input_type: tf.DType = tf.float32,
     inference_output_type: tf.DType = tf.float32,
+    target: str = "arm",
     experimental_default_int8_range: Optional[Tuple[float, float]] = None,
     experimental_enable_bitpacked_activations: bool = False,
 ) -> bytes:
@@ -73,6 +74,7 @@ def convert_keras_model(
             must be either `tf.float32` or `tf.int8`.
         inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
             must be either `tf.float32` or `tf.int8`.
+        target: Target hardware platform. Must be "arm" or "xcore".
         experimental_default_int8_range: Tuple of integers representing `(min, max)`
             range values for all arrays without a specified range. Intended for
             experimenting with quantization via "dummy quantization". (default None)
@@ -98,6 +100,8 @@ def convert_keras_model(
             "Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
             f"got {inference_output_type}."
         )
+    if target not in ("arm", "xcore"):
+        raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')

     if not tf.executing_eagerly():
         raise RuntimeError(
@@ -147,6 +151,7 @@ def convert_keras_model(
         [tensor.shape.as_list() for tensor in input_tensors],
         [get_tensor_name(tensor) for tensor in output_tensors],
         should_quantize,
+        target,
         experimental_default_int8_range,
         experimental_enable_bitpacked_activations,
     )
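For reference, a minimal usage sketch of the new keyword argument. The
top-level `larq_compute_engine` import path is an assumption (the diff only
shows the module-internal function); QuickNet mirrors the tests below:

    import larq_zoo as lqz
    from larq_compute_engine import convert_keras_model  # import path assumed

    model = lqz.sota.QuickNet(weights=None)

    # `target` defaults to "arm", so these two calls are equivalent:
    flatbuffer_bytes = convert_keras_model(model)
    flatbuffer_bytes = convert_keras_model(model, target="arm")

    # New in this commit: lower the same model for the xcore target instead.
    xcore_bytes = convert_keras_model(model, target="xcore")

    # Any other value is rejected before conversion starts.
    try:
        convert_keras_model(model, target="x86")
    except ValueError as err:
        print(err)  # Expected `target` to be "arm" or "xcore", but got x86.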

larq_compute_engine/mlir/python/converter_test.py (15 additions, 0 deletions)

@@ -31,6 +31,7 @@ def test_larq_zoo_models(self):
             [[1, 224, 224, 3]],
             ["Identity"],
             False,
+            "arm",
             None,
             False,
         )
@@ -39,6 +40,20 @@ def test_wrong_arg(self):
         with self.assertRaises(ValueError):
             convert_keras_model("./model.h5")

+    def test_target_arg(self):
+        with context.eager_mode():
+            model = lqz.sota.QuickNet(weights=None)
+
+            # These should work
+            convert_keras_model(model, target="arm")
+            convert_keras_model(model, target="xcore")
+
+            # Anything else shouldn't
+            with self.assertRaises(
+                ValueError, msg='Expected `target` to be "arm" or "xcore"'
+            ):
+                convert_keras_model(model, target="x86")
+

 if __name__ == "__main__":
     unittest.main()

larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc (11 additions, 2 deletions)

@@ -41,14 +41,23 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
     const std::vector<string>& input_dtypes,
     const std::vector<std::vector<int>>& input_shapes,
     const std::vector<string>& output_arrays, const bool should_quantize,
-    const pybind11::object& default_ranges,
+    const std::string& target_str, const pybind11::object& default_ranges,
     const bool experimental_enable_bitpacked_activations) {
   GraphDef graphdef;
   if (!tensorflow::LoadProtoFromBuffer(std::string(graphdef_bytes), &graphdef)
           .ok()) {
     throw std::runtime_error("Could not load GraphDef.");
   }

+  LCETarget target;
+  if (target_str == "arm") {
+    target = LCETarget::ARM;
+  } else if (target_str == "xcore") {
+    target = LCETarget::XCORE;
+  } else {
+    throw std::runtime_error("Invalid target.");
+  }
+
   GraphImportConfig specs;
   specs.prune_unused_nodes = true;
   specs.convert_legacy_fed_inputs = true;
@@ -88,7 +97,7 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
   }
   mlir::PassManager pm(&context);
   tensorflow::AddTFToLCETFLConversionPasses(
-      quant_specs, &pm, experimental_enable_bitpacked_activations);
+      quant_specs, &pm, target, experimental_enable_bitpacked_activations);

   // Convert back to outlined while format for export back to flatbuffer.
   pm.addPass(mlir::TFL::CreateWhileOutlinePass());

larq_compute_engine/mlir/tests/optimize.mlir (11 additions, 7 deletions)

@@ -1,4 +1,5 @@
-// RUN: lce-tf-opt %s -tfl-optimize-lce -verify-diagnostics | FileCheck %s
+// RUN: lce-tf-opt %s -tfl-optimize-lce=target-arm=true -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
+// RUN: lce-tf-opt %s -tfl-optimize-lce=target-arm=false -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-OTHER

 // CHECK-LABEL: @fuse_add_into_bconv2d
 func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
@@ -13,7 +14,6 @@ func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3
 // CHECK-NEXT: return %0
 }

-
 // CHECK-LABEL: @fuse_sub_into_bconv2d
 func @fuse_sub_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<0.5> : tensor<16xf32>
@@ -144,15 +144,19 @@ func @do_not_fuse_relu_into_bconv2d_no_post_activation_multiplier(%arg0: tensor<
 // CHECK-NEXT: return %1
 }

-// CHECK-LABEL: @reorder_maxpool_2d_quantize
-func @reorder_maxpool_2d_quantize(%arg0: tensor<256x32x32x65xf32>) -> tensor<256x16x8x3xi32> {
+// CHECK-LABEL: @target_specific_reorder_maxpool_2d_quantize
+func @target_specific_reorder_maxpool_2d_quantize(%arg0: tensor<256x32x32x65xf32>) -> tensor<256x16x8x3xi32> {
   %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 4 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x8x65xf32>
   %1 = "lq.Quantize"(%0) : (tensor<256x16x8x65xf32>) -> tensor<256x16x8x3xi32>
   return %1 : tensor<256x16x8x3xi32>

-  // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<256x32x32x65xf32>) -> tensor<256x32x32x3xi32>
-  // CHECK-NEXT: %1 = "lq.BMaxPool2d"(%0) {filter_height = 3 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 4 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x8x3xi32>
-  // CHECK-NEXT: return %1
+  // CHECK-ARM-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<256x32x32x65xf32>) -> tensor<256x32x32x3xi32>
+  // CHECK-ARM-NEXT: %1 = "lq.BMaxPool2d"(%0) {filter_height = 3 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 4 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x8x3xi32>
+  // CHECK-ARM-NEXT: return %1
+
+  // CHECK-OTHER-NEXT: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 4 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x8x65xf32>
+  // CHECK-OTHER-NEXT: %1 = "lq.Quantize"(%0) : (tensor<256x16x8x65xf32>) -> tensor<256x16x8x3xi32>
+  // CHECK-OTHER-NEXT: return %1
 }

 // CHECK-LABEL: @do_not_reorder_maxpool_2d_quantize_multiple_uses
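The CHECK-ARM/CHECK-OTHER split above encodes the user-visible behaviour: only
the ARM pipeline rewrites a float max-pool followed by lq.Quantize into the
bitpacked lq.BMaxPool2d op. A hedged sketch of how this surfaces through the
Python API; the toy model, the larq layer arguments, and the import path are
illustrative assumptions, not part of this commit:

    import tensorflow as tf
    import larq as lq
    from larq_compute_engine import convert_keras_model  # import path assumed

    # A max-pool between binarized convolutions, so after conversion a float
    # tfl.max_pool_2d is immediately followed by an lq.Quantize.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),
            lq.layers.QuantConv2D(
                16, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign"
            ),
            tf.keras.layers.MaxPool2D(),
            lq.layers.QuantConv2D(
                16, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign"
            ),
        ]
    )

    # Per the test above: with target="arm" the optimizer may reorder the pair
    # into lq.Quantize followed by lq.BMaxPool2d (pooling on bitpacked data);
    # with target="xcore" the float tfl.max_pool_2d is kept as-is.
    arm_bytes = convert_keras_model(model, target="arm")
    xcore_bytes = convert_keras_model(model, target="xcore")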
