Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b8b94f0
added function strip_lcedequantize_ops:
simonmaurer Mar 4, 2021
fd8f9d0
reformatted using black code style
simonmaurer Mar 4, 2021
bfa5590
added pytest module for verifying lce_dequantize_ops
simonmaurer Apr 12, 2021
658e148
fixed larq import errors and renamed unit test function
simonmaurer Apr 13, 2021
f860479
fix PyFlakes error due to typo when defining toy_model
simonmaurer Apr 13, 2021
46c0071
using Interpreter from larq_compute_engine.tflite.python.interpreter …
simonmaurer Apr 13, 2021
e50143f
reformatted strip_lcedequantize_test.py using black code style
simonmaurer Apr 13, 2021
3be2a3a
added function strip_lcedequantize_ops:
simonmaurer Mar 4, 2021
59d814f
reformatted using black code style
simonmaurer Mar 4, 2021
17e6b46
added pytest module for verifying lce_dequantize_ops
simonmaurer Apr 12, 2021
8b875f5
fixed larq import errors and renamed unit test function
simonmaurer Apr 13, 2021
dda14d3
fix PyFlakes error due to typo when defining toy_model
simonmaurer Apr 13, 2021
8c572fa
using Interpreter from larq_compute_engine.tflite.python.interpreter …
simonmaurer Apr 13, 2021
49a9877
reformatted strip_lcedequantize_test.py using black code style
simonmaurer Apr 13, 2021
be9b46e
Remove dependency of compute engine interpreter
lgeiger May 20, 2021
559ca80
Add bazel target for dequantize test
lgeiger May 20, 2021
5230f06
Update strip_lcedequantize_test.py
simonmaurer Jun 2, 2021
79c69dd
Update strip_lcedequantize_test.py
simonmaurer Jun 2, 2021
7cf93f8
Update strip_lcedequantize_test.py
simonmaurer Jun 6, 2021
9b986f1
fixed merge conflict in strip_lcedequantize_test.py
simonmaurer Jun 11, 2021
f2ef72c
fix: accidentally added merge indicators
simonmaurer Jun 11, 2021
5e31d80
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
8b518a6
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
bafc8d6
Adapt unit test for output type checking
simonmaurer Jul 28, 2021
9e6a268
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
ecff8d3
set tf.float32 as parametrized input type
simonmaurer Jul 29, 2021
2c45ed1
Updated strip_lcedequantize_ops() to support more models:
simonmaurer Jul 29, 2021
9a24a91
Unit tests for tf.int8 input/output models
simonmaurer Jul 29, 2021
77ee842
Correction in toy_model_int8_sign
simonmaurer Jul 30, 2021
cc44059
Extended Unit tests for test_strip_lcedequantize_ops() to parametrize…
simonmaurer Aug 1, 2021
c683131
Clean up using black code style
simonmaurer Aug 1, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ jobs:
run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all
- name: Run End2End tests
run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all
- name: Run Strip dequantize op tests
run: bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all

ConverterPython:
runs-on: ubuntu-latest
Expand Down
143 changes: 143 additions & 0 deletions larq_compute_engine/mlir/python/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,146 @@ def modify_integer_quantized_model_io_type(

# Convert the model to a bytearray
return _convert_model_from_object_to_bytearray(model)


def _find_output_ops(subgraph, opcode_idx):
    """Return operators using `opcode_idx` whose first output is a subgraph output.

    Returns an empty list when `opcode_idx` is None, i.e. when the model does
    not declare that operator code at all.
    """
    if opcode_idx is None:
        return []
    return [
        op
        for op in subgraph.operators
        if op.opcodeIndex == opcode_idx and op.outputs[0] in subgraph.outputs
    ]


def _strip_output_ops(model, subgraph, ops):
    """Bypass `ops` at the subgraph outputs, then delete them and their outputs.

    Every subgraph output (and every output of the first signature def, if
    any) that refers to the first output tensor of an op in `ops` is redirected
    to that op's first input tensor. The ops are removed from the subgraph and
    the now-dangling output tensors are handed to `_remove_tensors_from_model`.
    """
    removed_tensor_idxs = set()
    for op in ops:
        old_idx, new_idx = op.outputs[0], op.inputs[0]
        # Rewrite element by element instead of using a boolean mask: a mask
        # only works when `outputs` is a numpy array. On a plain list the
        # comparison evaluates to `False`, which would silently overwrite
        # element 0.
        for pos, tensor_idx in enumerate(subgraph.outputs):
            if tensor_idx == old_idx:
                subgraph.outputs[pos] = new_idx
        if model.signatureDefs:
            # Only the first signature def is kept in sync.
            for signature_output in model.signatureDefs[0].outputs:
                if signature_output.tensorIndex == old_idx:
                    signature_output.tensorIndex = new_idx
        removed_tensor_idxs.add(old_idx)
        subgraph.operators.remove(op)

    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, removed_tensor_idxs)


def strip_lcedequantize_ops(model):
    """Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors.

    Output `Dequantize` ops (int8 -> float32) are removed first, followed by
    the output `LceDequantize` ops (int32 -> float32/int8), so that the model
    outputs become the inputs of the removed ops.

    Args:
        model: The serialized model, in the form accepted by
            `_convert_model_from_bytearray_to_object`.

    Returns:
        The modified model as produced by
        `_convert_model_from_object_to_bytearray`.

    Raises:
        ValueError: If the model has more than one subgraph, contains neither
            an LceDequantize nor a Dequantize operator code, or if the tensors
            around an output (Lce)Dequantize op have unexpected types.
    """
    # Convert the model to an object
    model = _convert_model_from_bytearray_to_object(model)

    if len(model.subgraphs) > 1:
        raise ValueError(
            "Model must only have one subgraph. Instead, it has "
            "{} subgraphs.".format(len(model.subgraphs))
        )

    # Ensure model has at least one LceDequantize and/or Dequantize operator
    lce_dequant_opcode_idx, dequant_opcode_idx = None, None
    for idx, opcode in enumerate(model.operatorCodes):
        if opcode.customCode == b"LceDequantize":
            lce_dequant_opcode_idx = idx
        elif opcode.builtinCode == tflite_schema.BuiltinOperator.DEQUANTIZE:
            dequant_opcode_idx = idx
        if lce_dequant_opcode_idx is not None and dequant_opcode_idx is not None:
            break
    if lce_dequant_opcode_idx is None and dequant_opcode_idx is None:
        raise ValueError(
            "Model does not contain any LceDequantize or Dequantize operators."
        )

    # Ensure model outputs are dequantized and remove Dequantize ops first if any
    if dequant_opcode_idx is not None:
        subgraph = model.subgraphs[0]
        output_dequant_ops = _find_output_ops(subgraph, dequant_opcode_idx)
        # Validate every candidate before mutating anything.
        for op in output_dequant_ops:
            float_tensor = subgraph.tensors[op.outputs[0]]
            int_tensor = subgraph.tensors[op.inputs[0]]
            if float_tensor.type != tflite_schema.TensorType.FLOAT32:
                raise ValueError(
                    "Model {} type must be tf.float32. Expected type for tensor with "
                    "name '{}' is tf.float32, instead type is tf.{}".format(
                        "output",
                        float_tensor.name,
                        _convert_tflite_enum_type_to_tf_type(float_tensor.type).name,
                    )
                )
            if int_tensor.type != tflite_schema.TensorType.INT8:
                raise ValueError(
                    "Model is not integer quantized. Expected type for tensor with "
                    "name '{}' is tf.int8, instead type is tf.{}".format(
                        int_tensor.name,
                        _convert_tflite_enum_type_to_tf_type(int_tensor.type).name,
                    )
                )

        # Remove the Dequantize operators
        _strip_output_ops(model, subgraph, output_dequant_ops)

    # Re-read the subgraph: the previous pass may have changed which tensors
    # are outputs and how they are indexed.
    subgraph = model.subgraphs[0]

    # Ensure model outputs are Lce dequantized and remove LceDequantize ops
    lce_output_dequant_ops = _find_output_ops(subgraph, lce_dequant_opcode_idx)
    allowed_output_types = (
        tflite_schema.TensorType.FLOAT32,
        tflite_schema.TensorType.INT8,
    )
    for op in lce_output_dequant_ops:
        output_tensor = subgraph.tensors[op.outputs[0]]
        input_tensor = subgraph.tensors[op.inputs[0]]
        if output_tensor.type not in allowed_output_types:
            raise ValueError(
                "Model {} type must be tf.float32/tf.int8. Expected type for tensor with "
                "name '{}' is tf.float32/tf.int8, instead type is tf.{}".format(
                    "output",
                    output_tensor.name,
                    _convert_tflite_enum_type_to_tf_type(output_tensor.type).name,
                )
            )
        if input_tensor.type != tflite_schema.TensorType.INT32:
            raise ValueError(
                "Expected type for tensor with "
                "name '{}' is tf.int32, instead type is tf.{}".format(
                    input_tensor.name,
                    _convert_tflite_enum_type_to_tf_type(input_tensor.type).name,
                )
            )

    # Remove the LceDequantize operators
    _strip_output_ops(model, subgraph, lce_output_dequant_ops)

    # Convert the model to a bytearray
    return _convert_model_from_object_to_bytearray(model)
8 changes: 8 additions & 0 deletions larq_compute_engine/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ py_test(
],
)

# Python test target for strip_lcedequantize_test.py; it depends on the
# converter target under //larq_compute_engine/mlir.
py_test(
name = "strip_lcedequantize_test",
srcs = ["strip_lcedequantize_test.py"],
deps = [
"//larq_compute_engine/mlir:converter",
],
)

py_test(
name = "convert_model",
srcs = ["convert_model.py"],
Expand Down
73 changes: 73 additions & 0 deletions larq_compute_engine/tests/strip_lcedequantize_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import sys

import larq as lq
import pytest
import tensorflow as tf

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops


def toy_model_sign(**kwargs):
    """Build a toy Keras model: one binarized conv followed by a sign quantizer.

    The model takes a 224x224x3 input, applies a `QuantConv2D` layer with
    `ste_sign` input/kernel quantizers, and binarizes the result with
    `SteSign`. `kwargs` is accepted but not used.
    """
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    binary_conv = lq.layers.QuantConv2D(
        256,
        kernel_size=3,
        strides=1,
        padding="same",
        pad_values=1,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )
    sign = lq.quantizers.SteSign()
    outputs = sign(binary_conv(inputs))
    return tf.keras.Model(inputs=inputs, outputs=outputs)


def quant(x):
    """Fake-quantize `x` with fixed min/max values of -3.0 and 3.0."""
    lower, upper = -3.0, 3.0
    fake_quantized = tf.quantization.fake_quant_with_min_max_vars(x, lower, upper)
    return fake_quantized


def toy_model_int8_sign(**kwargs):
    """Build the toy binarized model with fake-quantized input and output.

    Same layer stack as a single `QuantConv2D` + `SteSign`, but the input is
    passed through `quant` before the convolution and the binarized result is
    passed through `quant` again. `kwargs` is accepted but not used.
    """
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    binary_conv = lq.layers.QuantConv2D(
        256,
        kernel_size=3,
        strides=1,
        padding="same",
        pad_values=1,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )
    sign = lq.quantizers.SteSign()

    quantized_inputs = quant(inputs)
    binarized = sign(binary_conv(quantized_inputs))
    outputs = quant(binarized)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


@pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
@pytest.mark.parametrize("inference_input_type", [tf.float32, tf.int8])
@pytest.mark.parametrize("inference_output_type", [tf.float32, tf.int8])
@pytest.mark.parametrize("experimental_enable_bitpacked_activations", [True, False])
def test_strip_lcedequantize_ops(
    model_cls,
    inference_input_type,
    inference_output_type,
    experimental_enable_bitpacked_activations,
):
    """Check that a converted and stripped model has a single tf.int32 output."""
    converter_options = {
        "inference_input_type": inference_input_type,
        "inference_output_type": inference_output_type,
        "experimental_enable_bitpacked_activations": experimental_enable_bitpacked_activations,
    }
    converted_model = convert_keras_model(model_cls(), **converter_options)
    stripped_model = strip_lcedequantize_ops(converted_model)

    # Load the stripped flatbuffer and inspect its output signature.
    output_details = tf.lite.Interpreter(
        model_content=stripped_model
    ).get_output_details()

    assert len(output_details) == 1
    assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype


if __name__ == "__main__":
    # Run this file's tests with output capturing disabled ("-s") and use the
    # pytest return value as the process exit status.
    exit_code = pytest.main([__file__, "-s"])
    sys.exit(exit_code)