Add support for saved model conversion #655

Conversation
Co-Authored-By: Tom Bannink <[email protected]>
# First attempt conversion as saved model
try:
    with tempfile.TemporaryDirectory() as saved_model_dir:
        model.save(saved_model_dir, save_format="tf")

        return convert_saved_model(
            saved_model_dir,
            inference_input_type=inference_input_type,
            inference_output_type=inference_output_type,
            experimental_default_int8_range=experimental_default_int8_range,
            experimental_enable_bitpacked_activations=experimental_enable_bitpacked_activations,
            target=target,
        )
except Exception:
    warnings.warn(
        "Saved-model conversion failed, falling back to graphdef-based conversion."
    )
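For context, a minimal usage sketch of how this fallback behaves from the caller's side, assuming the block above lives inside convert_keras_model as discussed later in this thread (the model definition and output filename are illustrative):

import tensorflow as tf
import larq_compute_engine as lce

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="relu", input_shape=(32,))]
)

# The converter now tries the SavedModel path first; if model.save() or
# convert_saved_model() raises, it warns and falls back to the old
# GraphDef-based conversion, so the caller gets flatbuffer bytes either way.
flatbuffer_bytes = lce.convert_keras_model(model)

with open("model.tflite", "wb") as f:
    f.write(flatbuffer_bytes)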
This is now enabled by default. Let me know if you prefer to keep this behind a flag instead.
// Normally we'd only set `inference_type` to QINT8 when there are fake_quant
// nodes in the graph. However this did not work reliably, and even for float
// models it is fine to set the inference type to QINT8, so we do that by
// default.
quant_specs.inference_type = tensorflow::DT_QINT8;
This is not ideal, but I don't think this causes any issues.
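For illustration, a hedged sketch of how a caller might drive the int8 path through the Python API added in this PR. The parameter names come from the diff above; tf.int8 as the accepted value and the (-3.0, 3.0) range are assumptions for the example, not values taken from this PR:

import tensorflow as tf
import larq_compute_engine as lce

# With quant_specs.inference_type defaulting to QINT8, this is expected to
# work whether or not the graph contains fake_quant nodes.
flatbuffer_bytes = lce.convert_saved_model(
    "/path/to/saved_model",                       # illustrative path
    inference_input_type=tf.int8,                 # assumed accepted value
    inference_output_type=tf.int8,                # assumed accepted value
    # Assumed to supply a dummy (min, max) range for tensors that have no
    # recorded quantization ranges, analogous to the TFLite converter flag.
    experimental_default_int8_range=(-3.0, 3.0),
)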
Looks great, thanks :)
Just a couple of minor comments.
    experimental_default_int8_range: Optional[Tuple[float, float]] = None,
    experimental_enable_bitpacked_activations: bool = False,
) -> bytes:
    """Converts a SavedModel to TFLite flatbuffer.
I'm never sure how we should refer to things as being TensorFlow versus being Larq/LCE.
On the one hand, it's accurate because it is a TFLite flatbuffer, just with extra custom ops. On the other hand, I feel like it's weird for a user to use the larq_compute_engine package and have method docstrings that refer to TensorFlow but not Larq or LCE.
I'm probably overthinking this though - how do you feel about it?
This is pretty much copied from the convert_keras_model function, but I'm happy to rephrase it. Not sure how we refer to it in the guides though.
Yeah me neither, I think that's the real issue. Since we are going to be going through and updating our docs for the 0.6 release anyhow, we can maybe have a think about it then and not worry about it now.
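To make the new entry point concrete, a short sketch of the plain float path implied by the signature and docstring above (the SavedModel directory and output filename here are illustrative):

import larq_compute_engine as lce

# convert_saved_model returns the TFLite flatbuffer (with LCE custom ops)
# as bytes, so producing a .tflite file is just a matter of writing them out.
flatbuffer_bytes = lce.convert_saved_model("/path/to/saved_model")

with open("model_lce.tflite", "wb") as f:
    f.write(flatbuffer_bytes)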
What do these changes do?
This PR adds support for directly converting TensorFlow SavedModels to TFLite flatbuffers.
How Has This Been Tested?
We've been using this code internally for a while now and have added some test cases to the end2end tests.
Related issue number
Closes #407
Co-Authored-By: Tom Bannink [email protected]