Add save in safetensors format #784


Merged: 33 commits merged into r0.9.0 on Jun 30, 2025

Conversation

igor-iusupov (Contributor)

No description provided.

@gemini-code-assist (bot) left a comment

Hello @igor-iusupov, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here to provide a summary of this pull request. This PR introduces the capability to save PyTorch models in the safetensors format within the batchflow library. It achieves this by adding a new boolean parameter, use_safetensors, to the existing save method of the base PyTorch model class. When this parameter is set to True, the method now uses the safetensors.torch.save_file function to save the model's state dictionary, offering an alternative to the default PyTorch pickle format or other supported formats like ONNX and OpenVINO.

Highlights

  • Safetensors Support: Adds the ability to save PyTorch models using the safetensors format.
  • Model Saving: Modifies the base model's save method to include an option for safetensors output.

Changelog

  • batchflow/models/torch/base.py
    • Added the use_safetensors boolean parameter to the save method signature (around line 1671).
    • Updated the docstring for the save method to document the new use_safetensors parameter (around line 1691).
    • Implemented the logic within the save method to handle saving the model's state dictionary using safetensors.torch.save_file when use_safetensors is True (around lines 1761-1764).
    • Removed a blank line (around line 17).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces the capability to save PyTorch models in the safetensors format, which is a valuable addition for model interoperability and security. The core implementation of saving the state_dict using safetensors.torch.save_file is correct.

My review focuses on improving clarity for users regarding how this new option interacts with existing save functionalities, particularly concerning metadata preservation, the usage of the path argument, and the handling of mutually exclusive format flags. Addressing these points will enhance the robustness and usability of this feature.

Summary of Findings

  • Documentation Clarity for safetensors Save Option: The method docstring should be enhanced to clearly explain: 1) How the path argument is utilized when use_safetensors=True. 2) The mutual exclusivity and precedence if multiple use_... format flags are enabled. 3) Whether associated metadata (like model configuration, training iteration) is saved with the safetensors format, as it currently appears to save only the state_dict.
  • Metadata Preservation Consistency with safetensors: The current safetensors implementation only saves the model's state_dict, which differs from the ONNX and OpenVINO saving options that also store other model attributes (e.g., config, iteration count). This inconsistency could lead to loss of information if not intended. The safetensors.torch.save_file function supports a metadata argument, which could be used to store these attributes for consistency. If this omission is by design, it needs to be prominently documented.
  • Handling of Mutually Exclusive Save Format Flags: The save method's behavior when multiple format flags (use_onnx, use_openvino, use_safetensors) are set to True should be more robust or clearly documented. Currently, only the first format encountered in the if/elif conditional chain is used, which might not be obvious to users. Consider raising an error or explicitly documenting the priority.
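The mutual-exclusivity concern in the last point could be caught early by a small validation helper run at the top of save. A minimal sketch, assuming a plain keyword-flag interface; the helper name and return convention are hypothetical and not taken from the PR:

```python
def check_format_flags(use_onnx=False, use_openvino=False, use_safetensors=False):
    """Raise ValueError if more than one export-format flag is enabled.

    Returns the name of the enabled flag, or "pt" for the default
    PyTorch pickle format when no flag is set.
    """
    enabled = [name for name, flag in [("use_onnx", use_onnx),
                                       ("use_openvino", use_openvino),
                                       ("use_safetensors", use_safetensors)]
               if flag]
    if len(enabled) > 1:
        raise ValueError(f"Mutually exclusive save formats requested: {', '.join(enabled)}")
    return enabled[0] if enabled else "pt"
```

Raising instead of silently following the if/elif precedence makes the conflict visible to the user at call time.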

Merge Readiness

This pull request introduces a useful feature for saving models in the safetensors format. However, to ensure clarity for users and consistency with existing functionalities, I recommend addressing the points raised, particularly regarding metadata preservation and the handling/documentation of mutually exclusive save options. The identified high severity issue concerning metadata should be carefully considered. Once these aspects are clarified and potentially adjusted, the PR will be in a much stronger position for merging. As I am not authorized to approve pull requests, further review and approval by authorized maintainers will be necessary.

Comment on lines 1761 to 1764
elif use_safetensors:
    from safetensors.torch import save_file
    state_dict = self.model.state_dict()
    save_file(state_dict, path)

Severity: high

This new elif block correctly implements saving to safetensors. There are a couple of points to consider for enhancing robustness and consistency with other save formats:

  1. Handling Multiple Format Flags: If a user inadvertently sets multiple format flags to True (e.g., use_onnx=True and use_safetensors=True), only the first one encountered in the if/elif chain will be executed. This could be non-obvious. Would it be more user-friendly to:

    • Raise a ValueError if more than one of use_onnx, use_openvino, use_safetensors is True?
    • Or, at least, clearly document the order of precedence in the method's docstring?
  2. Metadata Preservation Consistency: The ONNX and OpenVINO saving paths preserve a dictionary of model attributes (derived from self.PRESERVE_ONNX and self.PRESERVE_OPENVINO) alongside the model weights. The current safetensors implementation, however, only saves the model.state_dict().
    Is this omission of metadata for the safetensors format an intentional design choice? If users rely on reloading models with their full context (e.g., full_config, iteration count, which are part of self.PRESERVE), this difference in behavior could be critical.
    The safetensors.torch.save_file function supports a metadata: Dict[str, str] argument. To maintain consistency, it would be beneficial to also save relevant preserved attributes when using safetensors. This would involve:

    • Defining a set of attributes to preserve for safetensors (e.g., self.PRESERVE_SAFETENSORS or using a subset of self.PRESERVE).
    • Serializing these attributes into a string-to-string dictionary compatible with safetensors metadata.
      For example:
    # Inside the `elif use_safetensors:` block:
    from safetensors.torch import save_file
    import json  # For serializing complex objects

    state_dict = self.model.state_dict()

    # Example: Prepare metadata (attributes and serialization logic need careful consideration)
    # Note: safetensors metadata values must be strings.
    ignore_attributes = ignore_attributes or set()
    preserved_attrs_for_sf = set(self.PRESERVE) - ignore_attributes  # Or a dedicated PRESERVE_SAFETENSORS
    metadata_dict = {}
    for item_key in preserved_attrs_for_sf:
        if hasattr(self, item_key):
            attr_val = getattr(self, item_key)
            if isinstance(attr_val, (str, int, float, bool)):
                metadata_dict[item_key] = str(attr_val)
            elif isinstance(attr_val, (Config, dict, list)):
                try:
                    # Attempt to JSON serialize complex types
                    metadata_dict[item_key] = json.dumps(attr_val.config if isinstance(attr_val, Config) else attr_val)
                except TypeError:
                    # Fallback for non-serializable types
                    metadata_dict[item_key] = f"Unserializable: <{type(attr_val).__name__}>"
            # else: might log a warning for attributes not easily convertible to string metadata

    save_file(state_dict, path, metadata=metadata_dict)

    If omitting metadata is a deliberate choice (e.g., for a minimal, weights-only save), this distinction should be very clearly documented.

Comment on lines 1691 to 1692
use_safetensors : bool
    Whether to store model as safetensors file.

Severity: medium

The docstring for use_safetensors clearly states its purpose. To further improve user understanding of the save method with this new option, could we expand the documentation (either here or in the main method docstring) to cover:

  1. path Argument Behavior: Clarify that when use_safetensors=True, the main path argument is directly used as the filename for the .safetensors file. This is a key difference from the default behavior where path might point to a PyTorch pickle containing metadata, or from path_onnx/path_openvino which are separate arguments.

  2. Mutual Exclusivity of Formats: Explain how use_onnx, use_openvino, and use_safetensors interact if multiple are set to True. The current if/elif/else structure implies mutual exclusivity with a specific order of precedence (ONNX > OpenVINO > SafeTensors > default PyTorch save). Making this explicit would prevent user confusion.

  3. Metadata with safetensors: It would be helpful to document whether metadata (attributes from self.PRESERVE like full_config, iteration, etc.) is saved when use_safetensors=True. The current implementation appears to save only the state_dict, which differs from the ONNX and OpenVINO options that store a dictionary including such metadata. Highlighting this difference is important.
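Since safetensors restricts file metadata to a flat string-to-string mapping, preserved attributes such as iteration or full_config would have to be serialized to strings on save and parsed back on load. A minimal round-trip sketch using json; the helper names are hypothetical and this is not the PR's implementation:

```python
import json

def encode_metadata(attrs):
    """Serialize a dict of preserved attributes into the flat str->str
    mapping that safetensors file metadata requires."""
    return {key: json.dumps(value) for key, value in attrs.items()}

def decode_metadata(metadata):
    """Parse the string metadata values back into Python objects on load."""
    return {key: json.loads(value) for key, value in metadata.items()}
```

The encoded dict could then be passed as the metadata argument of safetensors.torch.save_file, and decoded after reading the file back.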

@igor-iusupov igor-iusupov marked this pull request as ready for review June 10, 2025 09:04
@igor-iusupov igor-iusupov requested a review from a team as a code owner June 10, 2025 09:04
@AlexeyKozhevin (Member)

Try to run save and load in different scenarios. That way you will find some errors, e.g. a missing preserved_dict for the safetensors case.

@AlexeyKozhevin AlexeyKozhevin requested a review from Copilot June 10, 2025 10:03

@AlexeyKozhevin (Member)

Fix linting issues

@roman-kh roman-kh changed the base branch from master to r0.9.0 June 11, 2025 05:50
@roman-kh roman-kh force-pushed the r0.9.0 branch 5 times, most recently from 1805059 to 028d383 Compare June 11, 2025 07:20
@AlexeyKozhevin AlexeyKozhevin requested a review from Copilot June 11, 2025 16:19

model = convert(file).eval()
self.model = model

self.model_to_device()

Check this call: model.load('tmp.safetensors', fmt='safetensors', pickle_metadata=False, device='cpu') creates the model on CUDA.
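The bug reported here is that an explicit device='cpu' request still lands the model on CUDA. One way to make the target device explicit is a small resolver run before moving the model; a minimal, torch-free sketch where the helper name is hypothetical and CUDA availability is passed in as a flag rather than queried from torch:

```python
def resolve_device(requested, cuda_available):
    """Return the device string to load onto, honoring an explicit request
    and falling back to 'cpu' when CUDA is requested but unavailable."""
    if requested is None:
        # No explicit request: prefer CUDA when present.
        return "cuda" if cuda_available else "cpu"
    if requested.startswith("cuda") and not cuda_available:
        return "cpu"
    return requested
```

With this shape, device='cpu' is always honored, which is the behavior the review comment asks to verify.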

@AlexeyKozhevin AlexeyKozhevin requested a review from Copilot June 27, 2025 13:56
@Copilot (AI) left a comment

Pull Request Overview

This PR adds support for saving models in the safetensors format while also refactoring the save/load methods for various model formats. Key changes include adding the safetensors dependency in pyproject.toml, tagging a slow test in research_test.py, and updating the save and load methods in the torch base model to accommodate new format options.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

  • pyproject.toml: Adds the safetensors dependency
  • batchflow/tests/research_test.py: Adds a @pytest.mark.slow decorator for a test case
  • batchflow/models/torch/base.py: Refactors model saving/loading to support "pt", "onnx", "openvino", and "safetensors"
Comments suppressed due to low confidence (1)

batchflow/models/torch/base.py:1859

  • The variable 'model_load_kwargs' is used without being defined or passed. Please define it or adjust the parameters to avoid a runtime error.
                model = OVModel(model_path=file, **model_load_kwargs)

@AlexeyKozhevin AlexeyKozhevin merged commit 6ded8cf into r0.9.0 Jun 30, 2025
1 check passed
@AlexeyKozhevin AlexeyKozhevin deleted the save-safetensors branch June 30, 2025 16:35
3 participants