Loading sharded model in tf from pytorch checkpoints #17537

Description

@ArthurZucker

System Info

- `transformers` version: 4.20.0.dev0
- Platform: macOS-12.4-arm64-arm-64bit
- Python version: 3.9.12
- Huggingface_hub version: 0.6.0
- PyTorch version (GPU?): 1.13.0.dev20220521 (False)
- Tensorflow version (GPU?): 2.9.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.4.2 (cpu)
- Jax version: 0.3.6
- JaxLib version: 0.3.5

Who can help?

@LysandreJik I am not sure who to ping on that 😅

Loading a big model from the Hub in TensorFlow is impossible if its PyTorch checkpoint is sharded: `from_pretrained(..., from_pt=True)` only looks for a single `pytorch_model.bin`, which sharded repos do not contain.
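For context, a sharded repo on the Hub ships a `pytorch_model.bin.index.json` plus numbered shard files instead of a single `pytorch_model.bin`. A minimal sketch of the filename resolution the loader would need (the helper name and error message are hypothetical, not the library's API):

```python
def pick_pytorch_checkpoint(repo_files):
    """Given the filenames in a Hub repo, decide which PyTorch
    checkpoint artifact to fetch (hypothetical helper)."""
    if "pytorch_model.bin" in repo_files:
        return "pytorch_model.bin"  # single-file checkpoint
    if "pytorch_model.bin.index.json" in repo_files:
        # sharded checkpoint: the index maps weight names to shard files
        return "pytorch_model.bin.index.json"
    raise EnvironmentError("repo has no PyTorch weights")

# facebook/opt-13b only has the sharded layout, so a loader that requests
# the single-file name gets the 404 shown in the traceback below:
print(pick_pytorch_checkpoint(
    ["config.json",
     "pytorch_model.bin.index.json",
     "pytorch_model-00001-of-00003.bin"]))
```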

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

>>> from transformers import TFOPTModel
>>> tf_model = TFOPTModel.from_pretrained("facebook/opt-13b", from_pt=True)
Traceback (most recent call last):
  File "/home/arthur_huggingface_co/transformers/src/transformers/modeling_tf_utils.py", line 1789, in from_pretrained
    resolved_archive_file = cached_path(
  File "/home/arthur_huggingface_co/transformers/src/transformers/utils/hub.py", line 282, in cached_path
    output_path = get_from_cache(
  File "/home/arthur_huggingface_co/transformers/src/transformers/utils/hub.py", line 486, in get_from_cache
    _raise_for_status(r)
  File "/home/arthur_huggingface_co/transformers/src/transformers/utils/hub.py", line 409, in _raise_for_status
    raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
transformers.utils.hub.EntryNotFoundError: 404 Client Error: Entry Not Found for url: https://huggingface.co/facebook/opt-13b/resolve/main/pytorch_model.bin

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/arthur_huggingface_co/transformers/src/transformers/modeling_tf_utils.py", line 1833, in from_pretrained
    raise EnvironmentError(
OSError: facebook/opt-13b does not appear to have a file named pytorch_model.bin.

In the meantime, the following script has to be used to convert the weights manually:

from transformers import OPTModel, TFOPTModel

path = "facebook/opt-13b"
# Load the sharded PyTorch checkpoint, then re-save it as a single file
# (the huge max_shard_size prevents it from being re-sharded).
pt_model = OPTModel.from_pretrained(path)
pt_model.save_pretrained(path, max_shard_size="1000GB")
# The single-file PyTorch checkpoint can now be crossloaded into TF.
tf_model = TFOPTModel.from_pretrained(path, from_pt=True)
tf_model.save_pretrained(path, save_config=False)
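An automatic fix would presumably need to resolve every shard listed in the index's `weight_map` before converting. A rough sketch of collecting the shard filenames from an index payload (the `{"weight_map": {param: file}}` layout is the format `save_pretrained` writes into `pytorch_model.bin.index.json`; the function name is hypothetical):

```python
import json

def shard_files_from_index(index_json_text):
    """Return the sorted set of shard filenames referenced by a
    pytorch_model.bin.index.json payload."""
    index = json.loads(index_json_text)
    return sorted(set(index["weight_map"].values()))

# A tiny stand-in index in the same shape the Hub stores:
index_text = json.dumps({
    "weight_map": {
        "decoder.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
        "decoder.layers.0.fc1.weight": "pytorch_model-00001-of-00002.bin",
        "decoder.layers.39.fc2.weight": "pytorch_model-00002-of-00002.bin",
    }
})
print(shard_files_from_index(index_text))
```

Each shard would then be downloaded and its tensors merged into one state dict before the usual PyTorch-to-TF weight crossloading runs.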

Expected behavior

Ideally, `TFOPTModel.from_pretrained(..., from_pt=True)` would detect the sharded PyTorch checkpoint and do this conversion automatically in the background.
