Commit bce36ee

Load sharded pt to flax (#18419)
* initial commit
* add small test
* add cross pt tf flag to test
* fix quality
* style
* update test with new repo
* fix failing test
* update
* fix wrong param ordering
* style
* update based on review
* update related to recent new caching mechanism
* quality
* Update based on review

  Co-authored-by: sgugger <[email protected]>

* quality and style
* Update src/transformers/modeling_flax_utils.py

  Co-authored-by: sgugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
1 parent c8b6ae8 commit bce36ee

File tree: 3 files changed (+94, -8 lines)

src/transformers/modeling_flax_pytorch_utils.py
src/transformers/modeling_flax_utils.py
tests/test_modeling_flax_common.py

src/transformers/modeling_flax_pytorch_utils.py

Lines changed: 67 additions & 7 deletions
```diff
@@ -38,7 +38,9 @@
 #####################
 
 
-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
         )
         raise
 
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
-
-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
 
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
     return flax_state_dict
 
 
@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+############################
+# Sharded Pytorch => Flax #
+############################
+
+
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # Load the index
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load using msgpack utils
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameters name to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+
+            pt_tuple_key = tuple(pt_key.split("."))
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)
+
+
 #####################
 # Flax => PyTorch #
 #####################
```
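
The new converter only needs the list of shard files and a target Flax model; `from_pretrained` assembles that list from the checkpoint index, but the function can also be driven by hand. Below is a minimal sketch (not part of this commit) under the assumption that `./sharded-bert` is a hypothetical local directory holding `config.json`, `pytorch_model.bin.index.json` and its shard files; the manual index parsing is only for illustration.

```python
import json
import os

from transformers import BertConfig, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_sharded_state_dict_to_flax

checkpoint_dir = "./sharded-bert"  # hypothetical local sharded PyTorch checkpoint

# The index JSON maps each weight name to the shard file that stores it.
with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
    index = json.load(f)
shard_filenames = [
    os.path.join(checkpoint_dir, name) for name in sorted(set(index["weight_map"].values()))
]

# Randomly initialized Flax model with the architecture matching the checkpoint's config.
flax_model = FlaxBertModel(BertConfig.from_pretrained(checkpoint_dir))

# Returns a nested dict of arrays with the same structure as flax_model.params.
flax_params = convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model)
```

In normal use none of this is needed: `from_pretrained(..., from_pt=True)` resolves the shard files itself, as the next file shows.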

src/transformers/modeling_flax_utils.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -40,6 +40,7 @@
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,
     add_code_sample_docstrings,
@@ -650,6 +651,10 @@ def from_pretrained(
             if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                # Load from a sharded pytorch checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                 # Load from a Flax checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -700,6 +705,13 @@ def from_pretrained(
                    )
                    if resolved_archive_file is not None:
                        is_sharded = True
+                # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
+                elif resolved_archive_file is None and from_pt:
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
                 if resolved_archive_file is None:
                     # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                     # message.
@@ -714,6 +726,12 @@ def from_pretrained(
                            f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                            " load this model from those weights."
                        )
+                    elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
+                            " `from_pt=True` to load this model from those weights."
+                        )
                     else:
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -761,7 +779,7 @@ def from_pretrained(
         model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
 
         if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
         else:
 
             if is_sharded:
```
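
The local-directory resolution order above can be summarized with a small sketch. It is a simplified mirror, not the real method: it assumes the filename constants are importable from `transformers.utils`, ignores the Flax index file, and the helper name and example path are made up for illustration.

```python
import os

from transformers.utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME


def describe_checkpoint_dir(path: str) -> str:
    """Report which checkpoint layout a local directory would resolve to, mirroring the branch order."""
    if os.path.isfile(os.path.join(path, WEIGHTS_NAME)):
        return f"single-file PyTorch checkpoint ({WEIGHTS_NAME}); load with from_pt=True"
    if os.path.isfile(os.path.join(path, WEIGHTS_INDEX_NAME)):
        return f"sharded PyTorch checkpoint ({WEIGHTS_INDEX_NAME}); load with from_pt=True"
    if os.path.isfile(os.path.join(path, FLAX_WEIGHTS_NAME)):
        return f"native Flax checkpoint ({FLAX_WEIGHTS_NAME})"
    return "no recognized weights file"


print(describe_checkpoint_dir("./my-checkpoint"))  # hypothetical local path
```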

tests/test_modeling_flax_common.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -1099,6 +1099,14 @@ def test_checkpoint_sharding_local(self):
                 for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
                     self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
 
+    @is_pt_flax_cross_test
+    def test_from_sharded_pt(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
+        for key, ref_val in flatten_dict(ref_model.params).items():
+            val = flatten_dict(model.params)[key]
+            assert np.allclose(np.array(val), np.array(ref_val))
+
     def test_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
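
The new test exercises a Hub repo; the same path can be checked fully locally. The sketch below (not from the commit) shards a tiny, randomly initialized PyTorch BERT with `save_pretrained(max_shard_size=...)` and reloads it in Flax via `from_pt=True`; the tiny config values and the shard size are arbitrary assumptions.

```python
import tempfile

import numpy as np
from flax.traverse_util import flatten_dict

from transformers import BertConfig, BertModel, FlaxBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # A small max_shard_size forces pytorch_model.bin.index.json plus several shard files.
    pt_model.save_pretrained(tmp_dir, max_shard_size="50kB")
    flax_model = FlaxBertModel.from_pretrained(tmp_dir, from_pt=True)

# Spot-check that every converted parameter came through as a finite array.
for key, value in flatten_dict(flax_model.params).items():
    assert np.isfinite(np.array(value)).all(), key
```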

0 commit comments
