Add save in safetensors format #784
Conversation
Hello @igor-iusupov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request. This PR introduces the capability to save PyTorch models in the safetensors format within the `batchflow` library. It achieves this by adding a new boolean parameter, `use_safetensors`, to the existing `save` method of the base PyTorch model class. When this parameter is set to `True`, the method uses the `safetensors.torch.save_file` function to save the model's state dictionary, offering an alternative to the default PyTorch pickle format or other supported formats such as ONNX and OpenVINO.
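In usage terms, the new flag would be exercised roughly like this (a minimal sketch, assuming `model` is an already-initialized batchflow PyTorch model; the filename is arbitrary):

```python
# Save only the model weights in safetensors format via the new flag.
model.save('model.safetensors', use_safetensors=True)
```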
Highlights
- Safetensors Support: Adds the ability to save PyTorch models using the safetensors format.
- Model Saving: Modifies the base model's `save` method to include an option for safetensors output.
Changelog
- batchflow/models/torch/base.py
  - Added the `use_safetensors` boolean parameter to the `save` method signature (around line 1671).
  - Updated the docstring for the `save` method to document the new `use_safetensors` parameter (around line 1691).
  - Implemented the logic within the `save` method to handle saving the model's state dictionary using `safetensors.torch.save_file` when `use_safetensors` is `True` (around lines 1761-1764).
  - Removed a blank line (around line 17).
Code Review
This pull request introduces the capability to save PyTorch models in the `safetensors` format, which is a valuable addition for model interoperability and security. The core implementation of saving the `state_dict` using `safetensors.torch.save_file` is correct.
My review focuses on improving clarity for users regarding how this new option interacts with existing save functionalities, particularly concerning metadata preservation, the usage of the `path` argument, and the handling of mutually exclusive format flags. Addressing these points will enhance the robustness and usability of this feature.
Summary of Findings
- Documentation Clarity for `safetensors` Save Option: The method docstring should be enhanced to clearly explain: 1) how the `path` argument is utilized when `use_safetensors=True`; 2) the mutual exclusivity and precedence if multiple `use_...` format flags are enabled; 3) whether associated metadata (like model configuration, training iteration) is saved with the `safetensors` format, as it currently appears to save only the `state_dict`.
- Metadata Preservation Consistency with `safetensors`: The current `safetensors` implementation only saves the model's `state_dict`, which differs from the ONNX and OpenVINO saving options that also store other model attributes (e.g., config, iteration count). This inconsistency could lead to loss of information if not intended. The `safetensors.torch.save_file` function supports a `metadata` argument, which could be used to store these attributes for consistency. If this omission is by design, it needs to be prominently documented.
- Handling of Mutually Exclusive Save Format Flags: The `save` method's behavior when multiple format flags (`use_onnx`, `use_openvino`, `use_safetensors`) are set to `True` should be more robust or clearly documented. Currently, only the first format encountered in the `if/elif` conditional chain is used, which might not be obvious to users. Consider raising an error (a sketch follows this list) or explicitly documenting the priority.
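For the last finding, the guard could look roughly like this inside `save` (a sketch, not code from the PR):

```python
# Reject ambiguous combinations of format flags early.
enabled = [name for name, flag in (('use_onnx', use_onnx),
                                   ('use_openvino', use_openvino),
                                   ('use_safetensors', use_safetensors)) if flag]
if len(enabled) > 1:
    raise ValueError(f'At most one save format flag may be True, got: {enabled}')
```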
Merge Readiness
This pull request introduces a useful feature for saving models in the `safetensors` format. However, to ensure clarity for users and consistency with existing functionalities, I recommend addressing the points raised, particularly regarding metadata preservation and the handling/documentation of mutually exclusive save options. The identified high-severity issue concerning metadata should be carefully considered. Once these aspects are clarified and potentially adjusted, the PR will be in a much stronger position for merging. As I am not authorized to approve pull requests, further review and approval by authorized maintainers will be necessary.
batchflow/models/torch/base.py
Outdated
```python
elif use_safetensors:
    from safetensors.torch import save_file
    state_dict = self.model.state_dict()
    save_file(state_dict, path)
```
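For reference, a file produced by this block can be read back with safetensors' own loader; a minimal sketch, assuming `net` is the underlying `torch.nn.Module`:

```python
from safetensors.torch import load_file

# Load the tensors onto CPU and restore them into the module.
state_dict = load_file('model.safetensors', device='cpu')
net.load_state_dict(state_dict)
```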
This new `elif` block correctly implements saving to `safetensors`. There are a couple of points to consider for enhancing robustness and consistency with other save formats:

1. Handling Multiple Format Flags: If a user inadvertently sets multiple format flags to `True` (e.g., `use_onnx=True` and `use_safetensors=True`), only the first one encountered in the `if/elif` chain will be executed. This could be non-obvious. Would it be more user-friendly to:
   - Raise a `ValueError` if more than one of `use_onnx`, `use_openvino`, `use_safetensors` is `True`?
   - Or, at least, clearly document the order of precedence in the method's docstring?
2. Metadata Preservation Consistency: The ONNX and OpenVINO saving paths preserve a dictionary of model attributes (derived from `self.PRESERVE_ONNX` and `self.PRESERVE_OPENVINO`) alongside the model weights. The current `safetensors` implementation, however, only saves the `model.state_dict()`.
   Is this omission of metadata for the `safetensors` format an intentional design choice? If users rely on reloading models with their full context (e.g., `full_config`, the `iteration` count, which are part of `self.PRESERVE`), this difference in behavior could be critical.
   The `safetensors.torch.save_file` function supports a `metadata: Dict[str, str]` argument. To maintain consistency, it would be beneficial to also save relevant preserved attributes when using `safetensors`. This would involve:
   - Defining a set of attributes to preserve for `safetensors` (e.g., `self.PRESERVE_SAFETENSORS`, or using a subset of `self.PRESERVE`).
   - Serializing these attributes into a string-to-string dictionary compatible with `safetensors` metadata.

   For example:
```python
# Inside the `elif use_safetensors:` block:
from safetensors.torch import save_file
import json  # For serializing complex objects

state_dict = self.model.state_dict()

# Example: prepare metadata (attributes and serialization logic need careful consideration).
# Note: safetensors metadata values must be strings.
ignore_attributes = ignore_attributes or set()
preserved_attrs_for_sf = set(self.PRESERVE) - ignore_attributes  # Or a dedicated PRESERVE_SAFETENSORS

metadata_dict = {}
for item_key in preserved_attrs_for_sf:
    if hasattr(self, item_key):
        attr_val = getattr(self, item_key)
        if isinstance(attr_val, (str, int, float, bool)):
            metadata_dict[item_key] = str(attr_val)
        elif isinstance(attr_val, (Config, dict, list)):
            try:
                # Attempt to JSON-serialize complex types
                metadata_dict[item_key] = json.dumps(attr_val.config if isinstance(attr_val, Config) else attr_val)
            except TypeError:
                # Fallback for non-serializable types
                metadata_dict[item_key] = f"Unserializable: <{type(attr_val).__name__}>"
        # else: might log a warning for attributes not easily convertible to string metadata

save_file(state_dict, path, metadata=metadata_dict)
```

If omitting metadata is a deliberate choice (e.g., for a minimal, weights-only save), this distinction should be very clearly documented.
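If metadata were stored this way, it could later be inspected without loading any tensors, e.g. via `safe_open` (a sketch; the key names mirror the example above):

```python
import json
from safetensors import safe_open

# Read only the header metadata of the file; tensors stay on disk.
with safe_open('model.safetensors', framework='pt', device='cpu') as f:
    metadata = f.metadata() or {}  # str -> str mapping, None when absent

iteration = int(metadata.get('iteration', '0'))
full_config = json.loads(metadata['full_config']) if 'full_config' in metadata else None
```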
batchflow/models/torch/base.py
Outdated
```python
use_safetensors : bool
    Whether to store model as safetensors file.
```
The docstring for `use_safetensors` clearly states its purpose. To further improve user understanding of the `save` method with this new option, could we expand the documentation (either here or in the main method docstring) to cover:

1. `path` Argument Behavior: Clarify that when `use_safetensors=True`, the main `path` argument is directly used as the filename for the `.safetensors` file. This is a key difference from the default behavior, where `path` might point to a PyTorch pickle containing metadata, or from `path_onnx`/`path_openvino`, which are separate arguments.
2. Mutual Exclusivity of Formats: Explain how `use_onnx`, `use_openvino`, and `use_safetensors` interact if multiple are set to `True`. The current `if/elif/else` structure implies mutual exclusivity with a specific order of precedence (ONNX > OpenVINO > SafeTensors > default PyTorch save). Making this explicit would prevent user confusion.
3. Metadata with `safetensors`: It would be helpful to document whether metadata (attributes from `self.PRESERVE` like `full_config`, `iteration`, etc.) is saved when `use_safetensors=True`. The current implementation appears to save only the `state_dict`, which differs from the ONNX and OpenVINO options that store a dictionary including such metadata. Highlighting this difference is important.
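To make these three points concrete, the expanded docstring could read roughly as follows (a sketch; the actual `save` signature in `base.py` has more parameters than shown here):

```python
def save(self, path, use_onnx=False, use_openvino=False, use_safetensors=False, **kwargs):
    """ Save the model.

    Parameters
    ----------
    use_safetensors : bool
        Whether to store the model as a safetensors file. If True, `path` is
        used directly as the filename of the `.safetensors` file, and only the
        model's `state_dict` is stored (no `PRESERVE` metadata).

    Notes
    -----
    `use_onnx`, `use_openvino` and `use_safetensors` are mutually exclusive.
    If several are True, the first one in the order ONNX > OpenVINO >
    safetensors takes precedence; otherwise the default PyTorch pickle
    format is used.
    """
```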
Try to run

Fix linting issues
Force-pushed from 1805059 to 028d383
batchflow/models/torch/base.py
Outdated
```python
model = convert(file).eval()
self.model = model

self.model_to_device()
```
Check this call: `model.load('tmp.safetensors', fmt='safetensors', pickle_metadata=False, device='cpu')` creates the model on CUDA even though `device='cpu'` was requested.
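A quick way to verify the reported behavior (a sketch; that `model.model` holds the underlying `torch.nn.Module` is an assumption about batchflow's internals):

```python
# Check where parameters actually land after loading.
model.load('tmp.safetensors', fmt='safetensors', pickle_metadata=False, device='cpu')
device = next(model.model.parameters()).device
assert device.type == 'cpu', f'expected cpu, got {device}'
```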
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Alexey Kozhevin <[email protected]>
Force-pushed from 1296356 to 169f58a
Pull Request Overview
This PR adds support for saving models in the safetensors format while also refactoring the save/load methods for various model formats. Key changes include adding the safetensors dependency in pyproject.toml, tagging a slow test in research_test.py, and updating the save and load methods in the torch base model to accommodate new format options.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| pyproject.toml | Adds the safetensors dependency |
| batchflow/tests/research_test.py | Adds a @pytest.mark.slow decorator for a test case |
| batchflow/models/torch/base.py | Refactors model saving/loading to support "pt", "onnx", "openvino", and "safetensors" |
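Based on this description and the `fmt='safetensors'` load call quoted earlier in the thread, a roundtrip through the refactored interface would presumably look like this (a sketch; the `save` side signature is an assumption made by symmetry):

```python
# Assumed fmt-based API after the refactor; not verbatim from the diff.
model.save('model.safetensors', fmt='safetensors')
model.load('model.safetensors', fmt='safetensors', device='cpu')
```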
Comments suppressed due to low confidence (1)
batchflow/models/torch/base.py:1859
- The variable `model_load_kwargs` is used without being defined or passed. Please define it or adjust the parameters to avoid a runtime error.

```python
model = OVModel(model_path=file, **model_load_kwargs)
```
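One minimal resolution, as a sketch (whether `model_load_kwargs` should instead become a parameter of `load` is a design decision for the authors):

```python
# Default the kwargs before unpacking to avoid the NameError.
model_load_kwargs = model_load_kwargs if model_load_kwargs is not None else {}
model = OVModel(model_path=file, **model_load_kwargs)
```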
No description provided.