 #####################
 
 
-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
         )
         raise
 
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
-
-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
 
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
     return flax_state_dict
 
 
@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+############################
+# Sharded Pytorch => Flax #
+############################
+
+
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # Accumulate the converted weights from every shard
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load the shard with torch.load and convert its tensors to numpy arrays
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameter names to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+
+            pt_tuple_key = tuple(pt_key.split("."))
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)
+
+
 #####################
 # Flax => PyTorch #
 #####################
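
For reference, a minimal usage sketch (not part of the commit) of the new sharded path, assuming the patched module is importable as transformers.modeling_flax_pytorch_utils and that both PyTorch and Flax are installed. The tiny BertConfig and the shard file names are only illustrative; any base model with matching PyTorch/Flax classes would do.

import torch
from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict

# Small random model so the example runs quickly.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config)
flax_model = FlaxBertModel(config)

# Split the PyTorch state dict into two .pt shard files, mimicking a sharded checkpoint on disk.
state_dict = pt_model.state_dict()
keys = list(state_dict.keys())
shards = [
    {k: state_dict[k] for k in keys[: len(keys) // 2]},
    {k: state_dict[k] for k in keys[len(keys) // 2 :]},
]
shard_files = []
for i, shard in enumerate(shards):
    filename = f"shard_{i}.pt"  # illustrative name; only the list of paths matters
    torch.save(shard, filename)
    shard_files.append(filename)

# With is_sharded=True, the list of shard files is passed straight through to
# convert_pytorch_sharded_state_dict_to_flax, which converts one shard at a time.
flax_params = load_pytorch_checkpoint_in_flax_state_dict(flax_model, shard_files, is_sharded=True)

Because conversion happens shard by shard, only one PyTorch shard needs to be held in memory at a time while the converted Flax weights accumulate.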