Merged
Changes from 12 commits
31 commits
b8b94f0
added function strip_lcedequantize_ops:
simonmaurer Mar 4, 2021
fd8f9d0
reformatted using black code style
simonmaurer Mar 4, 2021
bfa5590
added pytest module for verifying lce_dequantize_ops
simonmaurer Apr 12, 2021
658e148
fixed larq import errors and renamed unit test function
simonmaurer Apr 13, 2021
f860479
fix PyFlakes error due to typo when defining toy_model
simonmaurer Apr 13, 2021
46c0071
using Interpreter from larq_compute_engine.tflite.python.interpreter …
simonmaurer Apr 13, 2021
e50143f
reformatted strip_lcedequantize_test.py using black code style
simonmaurer Apr 13, 2021
3be2a3a
added function strip_lcedequantize_ops:
simonmaurer Mar 4, 2021
59d814f
reformatted using black code style
simonmaurer Mar 4, 2021
17e6b46
added pytest module for verifying lce_dequantize_ops
simonmaurer Apr 12, 2021
8b875f5
fixed larq import errors and renamed unit test function
simonmaurer Apr 13, 2021
dda14d3
fix PyFlakes error due to typo when defining toy_model
simonmaurer Apr 13, 2021
8c572fa
using Interpreter from larq_compute_engine.tflite.python.interpreter …
simonmaurer Apr 13, 2021
49a9877
reformatted strip_lcedequantize_test.py using black code style
simonmaurer Apr 13, 2021
be9b46e
Remove dependency of compute engine interpreter
lgeiger May 20, 2021
559ca80
Add bazel target for dequantize test
lgeiger May 20, 2021
5230f06
Update strip_lcedequantize_test.py
simonmaurer Jun 2, 2021
79c69dd
Update strip_lcedequantize_test.py
simonmaurer Jun 2, 2021
7cf93f8
Update strip_lcedequantize_test.py
simonmaurer Jun 6, 2021
9b986f1
fixed merge conflict in strip_lcedequantize_test.py
simonmaurer Jun 11, 2021
f2ef72c
fix: accidentally added merge indicators
simonmaurer Jun 11, 2021
5e31d80
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
8b518a6
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
bafc8d6
Adapt unit test for output type checking
simonmaurer Jul 28, 2021
9e6a268
Update strip_lcedequantize_test.py
simonmaurer Jul 28, 2021
ecff8d3
set tf.float32 as parametrized input type
simonmaurer Jul 29, 2021
2c45ed1
Updated strip_lcedequantize_ops() to support more models:
simonmaurer Jul 29, 2021
9a24a91
Unit tests for tf.int8 input/output models
simonmaurer Jul 29, 2021
77ee842
Correction in toy_model_int8_sign
simonmaurer Jul 30, 2021
cc44059
Extended Unit tests for test_strip_lcedequantize_ops() to parametrize…
simonmaurer Aug 1, 2021
c683131
Clean up using black code style
simonmaurer Aug 1, 2021
2 changes: 2 additions & 0 deletions .github/workflows/unittests.yml
@@ -95,6 +95,8 @@ jobs:
        run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all
      - name: Run End2End tests
        run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all
      - name: Run Strip dequantize op tests
        run: bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all

  ConverterPython:
    runs-on: ubuntu-latest
76 changes: 76 additions & 0 deletions larq_compute_engine/mlir/python/util.py
@@ -225,3 +225,79 @@ def modify_integer_quantized_model_io_type(

    # Convert the model to a bytearray
    return _convert_model_from_object_to_bytearray(model)


def strip_lcedequantize_ops(model):
    """Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors."""
    # Convert the model to an object
    model = _convert_model_from_bytearray_to_object(model)

    if len(model.subgraphs) > 1:
        raise ValueError(
            "Model must only have one subgraph. Instead, it has "
            "{} subgraphs.".format(len(model.subgraphs))
        )

    ## Find the LceDequantize operators
    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators
    remove_tensors_idxs = set()

    # Ensure model has at least one LceDequantize operator
    lce_dequant_opcode_idx = None
    for idx, opcode in enumerate(model.operatorCodes):
        if opcode.customCode == b"LceDequantize":
            lce_dequant_opcode_idx = idx
        if lce_dequant_opcode_idx is not None:
            break
    if lce_dequant_opcode_idx is None:
        raise ValueError("Model does not contain any LceDequantize operators.")

    # Ensure model outputs are dequantized
    lce_output_dequant_ops = []
    for op in operators:
        # Find output LceDequantize operator
        if (
            op.opcodeIndex == lce_dequant_opcode_idx
            and op.outputs[0] in subgraph.outputs
        ):
            pos, float_tensor, int_tensor = (
                "output",
                tensors[op.outputs[0]],
                tensors[op.inputs[0]],
            )
            lce_output_dequant_ops.append(op)
        # Otherwise, ignore
        else:
            continue
        # If found, validate the input/output tensor type
        if float_tensor.type != tflite_schema.TensorType.FLOAT32:
            raise ValueError(
                "Model {} type must be tf.float32. Expected type for tensor with "
                "name '{}' is tf.float32, instead type is tf.{}".format(
                    pos,
                    float_tensor.name,
                    _convert_tflite_enum_type_to_tf_type(float_tensor.type).name,
                )
            )
        if int_tensor.type != tflite_schema.TensorType.INT32:
            raise ValueError(
                "Expected type for tensor with "
                "name '{}' is tf.int32, instead type is tf.{}".format(
                    int_tensor.name,
                    _convert_tflite_enum_type_to_tf_type(int_tensor.type).name,
                )
            )

    # Remove the LceDequantize operators
    for op in lce_output_dequant_ops:
        subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
        remove_tensors_idxs.add(op.outputs[0])
        operators.remove(op)

    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)

    # Convert the model to a bytearray
    return _convert_model_from_object_to_bytearray(model)
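
The rewiring step in `strip_lcedequantize_ops` may be easier to follow on a toy graph. The sketch below is not the TFLite schema API: `ToyOp`, `ToySubgraph` and `strip_output_dequantize` are hypothetical stand-ins, and a plain list replaces the boolean-mask assignment on `subgraph.outputs` (which appears to assume that `subgraph.outputs` is a NumPy array). It only illustrates the idea of pointing each graph output at the input of its trailing dequantize op and then dropping that op.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class ToyOp:
    """Hypothetical stand-in for a TFLite operator (not the real schema class)."""

    name: str
    inputs: List[int]
    outputs: List[int]


@dataclass
class ToySubgraph:
    """Hypothetical stand-in for a TFLite subgraph."""

    operators: List[ToyOp]
    outputs: List[int]  # indices of the tensors exposed as graph outputs


def strip_output_dequantize(
    subgraph: ToySubgraph, op_name: str = "LceDequantize"
) -> Set[int]:
    """Rewire graph outputs past trailing `op_name` ops.

    Returns the indices of the tensors that are no longer referenced.
    """
    removed_tensors: Set[int] = set()
    # Only ops whose first output is a graph output are candidates.
    trailing = [
        op
        for op in subgraph.operators
        if op.name == op_name and op.outputs[0] in subgraph.outputs
    ]
    for op in trailing:
        # Point every graph output that used the float tensor at the op's input.
        subgraph.outputs = [
            op.inputs[0] if idx == op.outputs[0] else idx
            for idx in subgraph.outputs
        ]
        removed_tensors.add(op.outputs[0])
        subgraph.operators.remove(op)
    return removed_tensors


# Tensors in this toy graph:
# 0 = float input, 1 = bitpacked int32 activations, 2 = dequantized float output.
graph = ToySubgraph(
    operators=[ToyOp("LceBconv2d", [0], [1]), ToyOp("LceDequantize", [1], [2])],
    outputs=[2],
)
to_delete = strip_output_dequantize(graph)
print(graph.outputs, sorted(to_delete), [op.name for op in graph.operators])
```

After the call, the toy graph exposes tensor 1 directly and tensor 2 is marked for deletion, which corresponds to what `_remove_tensors_from_model` is handed in the real function.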
8 changes: 8 additions & 0 deletions larq_compute_engine/tests/BUILD
@@ -20,6 +20,14 @@ py_test(
    ],
)

py_test(
    name = "strip_lcedequantize_test",
    srcs = ["strip_lcedequantize_test.py"],
    deps = [
        "//larq_compute_engine/mlir:converter",
    ],
)

py_test(
    name = "convert_model",
    srcs = ["convert_model.py"],
76 changes: 76 additions & 0 deletions larq_compute_engine/tests/strip_lcedequantize_test.py
@@ -0,0 +1,76 @@
import sys

import larq as lq
import pytest
import tensorflow as tf

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops


def toy_model_sign(**kwargs):
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = lq.layers.QuantConv2D(
        256,
        kernel_size=3,
        strides=1,
        padding="same",
        pad_values=1,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )(img)
    x = lq.quantizers.SteSign()(x)
    return tf.keras.Model(inputs=img, outputs=x)


def quant(x):
    return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)


def toy_model_int8_sign(**kwargs):
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = quant(img)
    x = lq.layers.QuantConv2D(
        256,
        kernel_size=3,
        strides=1,
        padding="same",
        pad_values=1,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )(x)
    x = lq.quantizers.SteSign()(x)
    x = quant(x)
    return tf.keras.Model(inputs=img, outputs=x)


@pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
@pytest.mark.parametrize("inference_input_type", [tf.int8, tf.float32])
@pytest.mark.parametrize("inference_output_type", [tf.int8, tf.float32])
def test_strip_lcedequantize_ops(
    model_cls, inference_input_type, inference_output_type
):
    model_lce = convert_keras_model(
        model_cls(),
        inference_input_type=inference_input_type,
        inference_output_type=inference_output_type,
        experimental_default_int8_range=None,
        experimental_enable_bitpacked_activations=True,
    )
    model_lce = strip_lcedequantize_ops(model_lce)
    interpreter = tf.lite.Interpreter(model_content=model_lce)
    input_details = interpreter.get_input_details()
    assert len(input_details) == 1
    assert input_details[0]["dtype"] == inference_input_type.as_numpy_dtype
    output_details = interpreter.get_output_details()
    assert len(output_details) == 1
    if inference_output_type == tf.float32:
        assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype
    else:
        assert output_details[0]["dtype"] == inference_output_type.as_numpy_dtype
Member

Does this mean if output inference type is tf.int8 the dequantize op is not removed?
@simonmaurer Is this intended behaviour?

Contributor Author

yes, this is intended behavior.

as discussed here:

Sorry for the late reply.

In that case these checks are not necessary, or am I missing something here?

To be honest, I am not 100% sure about this. This code was originally adapted from tensorflow/tensorflow@ec94fab but it looks like there've been additional changes in the latest version which would be great to apply here as well in another PR.
I would need to take a closer look if these errors would ever be thrown here, but if you think they won't, feel free to remove the checks for them in strip_lcedequantize_ops.

Could you still add a unittest similar to this one that tests that the functionality works as expected?
I do understand that you manually verified that the code works, but it is important for us to have these checks also run on CI so we can ensure that it also works with older versions of TensorFlow and crucially that the code will keep working and we don't accidentally break it in a future PR.

and when looking at the graph of tf.float32/tf.int8 models there are only LceDequantize ops when the model has tf.float32 outputs.
in other words the LceDequantize ops are removed (resulting in the desired tf.int32 outputs) iff the model has tf.float32 outputs
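
The behaviour described in this reply, as encoded in the final assertions of the test at this commit, can be sketched without TensorFlow. In the snippet below, plain strings stand in for the `tf` dtypes and `expected_output_dtype` is a hypothetical helper that is not part of the PR. It also enumerates the cases produced by the three stacked `pytest.mark.parametrize` decorators, which combine as a cross product.

```python
from itertools import product

# String stand-ins for the parametrized values in the test (assumption:
# names only, no TensorFlow objects are involved in this sketch).
MODELS = ["toy_model_sign", "toy_model_int8_sign"]
IO_TYPES = ["int8", "float32"]


def expected_output_dtype(inference_output_type: str) -> str:
    """Mirror the test's assertion: float32 outputs become bitpacked int32."""
    if inference_output_type == "float32":
        return "int32"
    return inference_output_type


# Stacked parametrize decorators yield every combination of their values.
cases = [
    (model, in_type, out_type, expected_output_dtype(out_type))
    for model, in_type, out_type in product(MODELS, IO_TYPES, IO_TYPES)
]
for case in cases:
    print(case)
```

Each tuple reads as (model, input type, requested output type, dtype the test expects after stripping). Whether the `int8` rows should also end up as `int32` is the open question in this thread.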

Contributor Author

simonmaurer Jun 2, 2021

actually it shouldn't differ. as long as there are LceDequantize ops as outputs they should be removed to get tf.int32 outputs, no matter what tf.float32/tf.int8 ops are inside the model

since you've pointed this out: I've removed this check for tf.int8 models in the unittest function with the last commit, but actually it could (or should) still be in there. not sure though why the MLIR test did not go through for tf.int8 output models..

Contributor Author

MLIR failed again for the assertion tests

Member

You might also want to apply similar changes as done in #635 to this PR, but I don't think this will fix CI either. Looks like the dequantize node is not correctly removed in your example.



if __name__ == "__main__":
    sys.exit(pytest.main([__file__, "-s"]))