
Conversation

simonmaurer
Contributor

@simonmaurer simonmaurer commented Mar 4, 2021

  • strip_lcedequantize_ops: strips the output LceDequantize operators of a model such that the output is a bitpacked tf.int32 tensor

What do these changes do?

Usually the lce_converter dequantizes the bitpacked output back to tf.float32/tf.int8, resulting in an identity tensor. This is intended for training. However, for inference or post-processing, one could use the bitpacked tf.int32 tensors directly.
By using strip_lcedequantize_ops, one can strip the output LceDequantize operators of a model to get access to the bitpacked tf.int32 tensors.
Use cases: larq.layers.QuantConv2D followed by a sign operation (e.g. larq.math.sign or larq.quantizers.SteSign())

Update

Works for tf.float32 as well as for tf.int8-quantized models.
Also strips Dequantize ops, if any, i.e. due to dequantization when using inference_output_type=tf.int8.

How Has This Been Tested?

import tensorflow as tf
import larq as lq
from larq_compute_engine import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops


def toy_model_sign():
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = lq.layers.QuantConv2D(
        256, kernel_size=3, strides=1, padding="same", pad_values=1,
        input_quantizer="ste_sign", kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )(img)
    x = lq.quantizers.SteSign()(x)
    return tf.keras.Model(inputs=img, outputs=x)


def quant(x):
    # Fake-quantize activations to the range [-3, 3].
    return tf.quantization.fake_quant_with_min_max_vars(x, -3, 3)


def toy_model_int8_sign():
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = quant(img)
    x = lq.layers.QuantConv2D(
        256, kernel_size=3, strides=1, padding="same", pad_values=1,
        input_quantizer="ste_sign", kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )(x)
    x = lq.quantizers.SteSign()(x)
    x = quant(x)
    return tf.keras.Model(inputs=img, outputs=x)


model = toy_model_sign()  # or toy_model_int8_sign()
tflite_model = convert_keras_model(
    model,
    inference_input_type=tf.float32,
    inference_output_type=tf.float32,
    experimental_enable_bitpacked_activations=True,
)
tflite_model = strip_lcedequantize_ops(tflite_model)
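
For reference, a minimal sketch (not part of the PR) of how the stripped model's bitpacked output could be inspected with the standard tf.lite.Interpreter; the random input and the numpy import are additions of this example:

import numpy as np

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# After stripping, the output tensor is bitpacked tf.int32
# (32 binary values per element).
assert output_details[0]["dtype"] == np.int32

interpreter.set_tensor(
    input_details[0]["index"],
    np.random.uniform(-1, 1, size=input_details[0]["shape"]).astype(np.float32),
)
interpreter.invoke()
bitpacked_output = interpreter.get_tensor(output_details[0]["index"])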

Related issue number

#599

- strips the output LceDequantize operators of a model such that the output is a bitpacked tf.int32 tensor
- usually the lce_converter dequantizes the bitpacked output back to tf.float32 resulting in an identity tensor
- use cases: larq.layers.QuantConv2D followed by a sign operation (ie. larq.math.sign or larq.quantizers.SteSign())
- import using `from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops`
@lgeiger
Member

lgeiger commented Mar 4, 2021

@simonmaurer Thanks for the PR 🎉

Could you also add a small unittest for this functionality?
Probably something similar to this test and a test that checks that the errors are raised correctly should do.

@lgeiger lgeiger added the feature New feature or request label Mar 4, 2021
@simonmaurer
Contributor Author

simonmaurer commented Mar 6, 2021

So the conversion works, as seen in the following pictures:

[Images: Toy model (bitpacked) · Toy model (bitpacked tf.int32)]

I realized that the error check conditions are never met if the model has TF-quantized (tf.int8) inputs/outputs, because those use different ops anyway (TF does fake quantization with a Quantize op and never needs a dequantize, so there will never be an LceDequantize op at the output, as seen below):

[Images: Toy model (bitpacked with tf.int8 fake quantization) · Toy model (bitpacked tf.int32 with tf.int8 fake quantization)]

In that case these checks are not necessary, or am I missing something here?

@simonmaurer
Contributor Author

simonmaurer commented Mar 6, 2021

Sidenote on lce_converter: I finally understand what these lines, together with _find_int8_quantized_inputs_outputs, do:
they only remove the input Quantize and output Dequantize ops if the model contains manually placed tf.quantization.fake_quant_with_min_max_args ops AND only if inference_input_type = tf.int8 / inference_output_type = tf.int8.

The thing is, if there are no tf.quantization.fake_quant_with_min_max_args ops, how will these lines ever throw an error?
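
(For illustration only, a minimal sketch of the case those lines do handle, reusing toy_model_int8_sign from the snippet above; the argument combination here is my assumption, not the converter's internals:)

# With manually placed fake_quant ops (as in toy_model_int8_sign) and int8
# inference types requested, the converter drops the boundary input Quantize
# and output Dequantize ops, as described above.
tflite_int8_model = convert_keras_model(
    toy_model_int8_sign(),
    inference_input_type=tf.int8,
    inference_output_type=tf.int8,
    experimental_enable_bitpacked_activations=True,
)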

@lgeiger
Member

lgeiger commented Mar 11, 2021

Sorry for the late reply.

> In that case these checks are not necessary, or am I missing something here?

To be honest, I am not 100% sure about this. This code was originally adapted from tensorflow/tensorflow@ec94fab, but it looks like there have been additional changes in the latest version which would be great to apply here as well in another PR.
I would need to take a closer look at whether these errors would ever be thrown here, but if you think they won't, feel free to remove the checks for them in strip_lcedequantize_ops.

Could you still add a unittest similar to this one that tests that the functionality works as expected?
I do understand that you manually verified that the code works, but it is important for us to have these checks also run on CI so we can ensure that it also works with older versions of TensorFlow and, crucially, that the code will keep working and we don't accidentally break it in a future PR.

@simonmaurer
Contributor Author

> Could you still add a unittest similar to this one that tests that the functionality works as expected?

@lgeiger sorry for the delay.
I've added a small unit test similar to end2end_test to verify the strip_lcedequantize_ops function.
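
For reference, a hypothetical sketch of what such a test could look like (the names, shapes, and assertions here are assumptions, not the actual contents of strip_lcedequantize_test.py):

import larq as lq
import numpy as np
import tensorflow as tf

from larq_compute_engine import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops


def toy_model():
    img = tf.keras.layers.Input(shape=(32, 32, 3))
    x = lq.layers.QuantConv2D(
        16, 3, padding="same",
        input_quantizer="ste_sign", kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )(img)
    x = lq.quantizers.SteSign()(x)
    return tf.keras.Model(inputs=img, outputs=x)


def test_strip_lcedequantize_ops():
    converted = convert_keras_model(
        toy_model(),
        inference_input_type=tf.float32,
        inference_output_type=tf.float32,
        experimental_enable_bitpacked_activations=True,
    )
    stripped = strip_lcedequantize_ops(converted)

    # The stripped model should expose the bitpacked tf.int32 output tensor.
    interpreter = tf.lite.Interpreter(model_content=stripped)
    assert interpreter.get_output_details()[0]["dtype"] == np.int32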

@lgeiger
Member

lgeiger commented Apr 12, 2021

Thanks for adding the test. Could you check the linting errors on CI? It looks like pyflakes complains about some imports in the test file:

larq_compute_engine/tests/strip_lcedequantize_test.py:3:1 'larq as lq' imported but unused
larq_compute_engine/tests/strip_lcedequantize_test.py:4:1 'numpy as np' imported but unused
larq_compute_engine/tests/strip_lcedequantize_test.py:10:1 'larq_compute_engine.tflite.python.interpreter.Interpreter' imported but unused
larq_compute_engine/tests/strip_lcedequantize_test.py:15:9 undefined name 'larq'
larq_compute_engine/tests/strip_lcedequantize_test.py:16:9 undefined name 'larq'
larq_compute_engine/tests/strip_lcedequantize_test.py:25:9 undefined name 'larq'
larq_compute_engine/tests/strip_lcedequantize_test.py:26:9 undefined name 'larq'
larq_compute_engine/tests/strip_lcedequantize_test.py:39:69 undefined name 'toy_model'

@simonmaurer
Contributor Author

simonmaurer commented Apr 13, 2021

@lgeiger Fixed all import errors and reformatted the code using the black code style.
I've decided to put the unit test for the utility function strip_lcedequantize_ops in a separate strip_lcedequantize_test.py module to keep end2end_test.py untouched.
Of course it can be adapted; let me know if it works 👍

Member

@lgeiger lgeiger left a comment

@simonmaurer Sorry for the delay in reviewing this. I just rebased your PR onto master and added the unittests to our CI infrastructure in da432fd

Let's get this merged once CI passes 🚀

simonmaurer and others added 9 commits May 20, 2021 15:23
- strips the output LceDequantize operators of a model such that the output is a bitpacked tf.int32 tensor
- usually the lce_converter dequantizes the bitpacked output back to tf.float32 resulting in an identity tensor
- use cases: larq.layers.QuantConv2D followed by a sign operation (ie. larq.math.sign or larq.quantizers.SteSign())
- import using `from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops`
@lgeiger lgeiger force-pushed the bitpacked_int32 branch from da432fd to 559ca80 on May 20, 2021 14:23
Member

@lgeiger lgeiger left a comment

Thanks for sticking with this PR. I just have an additional question below about the intended behaviour the unittest is verifying.

if inference_output_type == tf.float32:
    assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype
else:
    assert output_details[0]["dtype"] == inference_output_type.as_numpy_dtype
Member

Does this mean that if the output inference type is tf.int8, the dequantize op is not removed?
@simonmaurer Is this intended behaviour?

Contributor Author

yes, this is intended behavior.

as discussed here:

> I would need to take a closer look at whether these errors would ever be thrown here, but if you think they won't, feel free to remove the checks for them in strip_lcedequantize_ops.

When looking at the graphs of the tf.float32/tf.int8 models, there are only LceDequantize ops at the output when the model has tf.float32 outputs.
In other words, the LceDequantize ops are removed (resulting in the desired tf.int32 outputs) if and only if the model has tf.float32 outputs.

Contributor Author

@simonmaurer simonmaurer Jun 2, 2021

Actually, it shouldn't differ: as long as there are LceDequantize ops at the outputs, they should be removed to get tf.int32 outputs, no matter which tf.float32/tf.int8 ops are inside the model.

Since you've pointed this out, I've removed this check for tf.int8 models in the unittest function with the last commit, but actually it could (or should) still be in there. I'm not sure though why the MLIR test did not pass for tf.int8 output models...

Contributor Author

MLIR failed again for the assertion tests

Member

You might also want to apply similar changes as done in #635 to this PR, but I don't think this will fix CI either. Looks like the dequantize node is not correctly removed in your example.

deactivate setting default int8 ranges for `tf.float32` models as the strip_lcedequantize_ops function will not remove `LceDequantize` ops
Contributor Author

@simonmaurer simonmaurer left a comment

@lgeiger After inspecting the unittest code again, I deactivated the default int8 ranges for tf.float32 models. As you said, the strip_lcedequantize_ops function would never remove the LceDequantize ops, since every model generated via the parametrized pytest functions resulted in tf.int8 outputs.

@simonmaurer simonmaurer requested a review from lgeiger June 6, 2021 23:01
Member

@lgeiger lgeiger left a comment

Still looks like CI is failing. Are you able to reproduce the failure locally?

@simonmaurer
Contributor Author

> Still looks like CI is failing. Are you able to reproduce the failure locally?

Sorry if this didn't make it into the v0.6 release... I didn't know your time schedule. It's been taking a while since I always thought I had to recompile everything.
Going to use your hint:
bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all

The code works for tf.float32 models; the failure is related to the unittests only. Working on it...

@lgeiger
Member

lgeiger commented Jun 11, 2021

> Sorry if this didn't make it into the v0.6 release... I didn't know your time schedule.

No worries, we can always push a patch release later.

> It's been taking a while since I always thought I had to recompile everything.

This should be a bit better now that we are relying on a stable TF version, so bazel should be able to properly cache the build artifacts.

simonmaurer and others added 2 commits June 11, 2021 15:10
Testing strip_lcedequantize_ops for tf.float32 output:
- fix double allocation of Interpreter, using tf.lite.Interpreter instead
- fix typo when converting model to TFLite model
@simonmaurer
Contributor Author

@lgeiger Testing tf.float32 model outputs (with and without quantize ops inside the model). After some more testing, I do think the unit tests are failing for tf.int8 model outputs because there is a Dequantize op after model conversion to TFLite (before strip_lcedequantize_ops).

removed import of Larq interpreter due to Lint tests failing
@simonmaurer
Contributor Author

> You might also want to apply similar changes as done in #635 to this PR, but I don't think this will fix CI either. Looks like the dequantize node is not correctly removed in your example.

@lgeiger I assume this is exactly what you meant here and not the Larq dequantize op ;)

simonmaurer and others added 5 commits July 29, 2021 00:40
- only validate output after LceDequantize ops have been stripped, input type tests already validated in end2end_test.py
fix: setting inference_input_type statically to tf.float32 as we're only validating the output
- updated signature defs for TF2.5 compatibility
- support int8-quantized models when stripping LceDequantize op for int8 output
- support int8-quantized models when using dequantized tf.float32 output, strips Dequantize operator first then LceDequantize
@simonmaurer simonmaurer requested a review from lgeiger July 30, 2021 08:49
@simonmaurer
Contributor Author

simonmaurer commented Jul 30, 2021

@lgeiger @AdamHillier Finally the unit tests work for both tf.float32 and tf.int8-quantized models. Additionally, I had to adapt the code to the following cases (see the sketch after this list):

  1. int8-quantized models with inference_output_type=tf.float32, i.e. when the output is dequantized back to tf.float32 (meaning there is the additional TFLite Dequantize op: LceConv2D -> LceDequantize -> Dequantize)
  2. int8-quantized models with inference_output_type=tf.int8, i.e. when the output is tf.int8 (LceConv2D -> LceDequantize)
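
A minimal sketch of those two conversion paths (reusing toy_model_sign from the first snippet; the experimental_default_int8_range argument is my assumption for how the int8-quantized toy models are produced):

# Case 1: int8-quantized toy model, output dequantized back to tf.float32,
#   i.e. LceConv2D -> LceDequantize -> Dequantize before stripping.
lce_float_out = convert_keras_model(
    toy_model_sign(),
    inference_input_type=tf.float32,
    inference_output_type=tf.float32,
    experimental_default_int8_range=(-3.0, 3.0),
    experimental_enable_bitpacked_activations=True,
)

# Case 2: int8-quantized toy model with tf.int8 output,
#   i.e. LceConv2D -> LceDequantize before stripping.
lce_int8_out = convert_keras_model(
    toy_model_sign(),
    inference_input_type=tf.float32,
    inference_output_type=tf.int8,
    experimental_default_int8_range=(-3.0, 3.0),
    experimental_enable_bitpacked_activations=True,
)

# In both cases, stripping removes the trailing (Lce)Dequantize ops and
# exposes the bitpacked tf.int32 output.
lce_float_out = strip_lcedequantize_ops(lce_float_out)
lce_int8_out = strip_lcedequantize_ops(lce_int8_out)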

Overall we have:

[Images: Toy model (inference_output_type=tf.float32) · Toy model (inference_output_type=tf.int8) · Toy model with int8 fake quantization (inference_output_type=tf.float32) · Toy model with int8 fake quantization (inference_output_type=tf.int8)]

@simonmaurer simonmaurer requested a review from AdamHillier July 30, 2021 09:13
@simonmaurer
Contributor Author

@lgeiger @Tombana Requesting a review given the latest commit, as the PR is now ready, including the updates to the unit tests for tf.float32/tf.int8 models.

Member

@lgeiger lgeiger left a comment

Great! Sorry for the long wait, let's get this merged now 🎉
