 class CheckpointConverter:
     def __init__(
-        self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
+        self,
+        hybrid_parallel_ckpt_path,
+        state_dict,
+        parameter_to_structured_name,
+        trainging_args=None,
+        patch_dict=None,
+        local_view_pattern: list | bool = None,
     ):
         self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
         self.path = hybrid_parallel_ckpt_path
@@ -85,6 +91,17 @@ def __init__(
                 self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
             for k in del_keys:
                 self.auto_parallel_state_dict.pop(k)
+        # Solve the problem of inconsistent parameter names in MoE auto parallel mode.
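+        # local_view_pattern=False disables the local-to-global renaming; a list overrides the
+        # default pattern list ["experts"]; any other value (e.g. None) falls back to that default.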
+        if hasattr(trainging_args, "moe_group") and trainging_args.moe_group:
+            if local_view_pattern is False:
+                self.local_view_pattern_list = None
+            else:
+                if isinstance(local_view_pattern, list):
+                    self.local_view_pattern_list = local_view_pattern
+                else:
+                    self.local_view_pattern_list = ["experts"]
+        else:
+            self.local_view_pattern_list = None
 
         flags = [
             ["tp degree", self.tp_degree],
@@ -497,6 +514,46 @@ def gen_metadata_and_prepare_source_state_dict(self):
         else:
             return self.gen_metadata_for_tp_sharded_tensor()
 
+    def rename_local_view_state_dict(self, state_dict, file_name):
+        """
+        Rename local-view keys to their global-view keys and return the renamed `state_dict`.
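+
+        For example, assuming the default pattern "experts", 2 experts per rank and tp_rank 1
+        (illustrative values): a key like "mlp.experts.0.w1" would become "mlp.experts.2.w1".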
520+ """
521+ if self .local_view_pattern_list is None :
522+ return state_dict
523+ # case 1: moe_group is mp_group
524+ if self .tp_degree > 1 and self .sharding_degree <= 1 :
525+ (tp_rank , pp_rank , sharding_rank ) = self .get_distribution_rank_from_file_name (file_name )
526+ expert_name_old2new = {}
527+ for pattern in self .local_view_pattern_list :
528+ expert_pattern = rf"({ pattern } \.)(\d+)"
529+ # extract all experts IDs
530+ expert_ids = set ()
531+ for state_name in state_dict .keys ():
532+ res = re .search (expert_pattern , state_name )
533+ if res :
534+ expert_ids .add (int (res .group (2 )))
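+                # number of experts stored locally in this checkpoint file (one rank's slice)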
+                expert_num = len(expert_ids)
+                # construct old name to new name mapping
+                for state_name in state_dict.keys():
+                    res = re.search(expert_pattern, state_name)
+                    if res:
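+                        # map the local expert id to its global id: each tp rank owns a
+                        # contiguous block of `expert_num` experts, offset by tp_rank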
+                        new_expert_id = int(res.group(2)) % expert_num + tp_rank * expert_num
+                        expert_name_old2new[state_name] = re.sub(
+                            expert_pattern, f"{res.group(1)}{new_expert_id}", state_name
+                        )
+            # rename state_dict
+            renamed_state_dict = {
+                expert_name_old2new[state_name]
+                if state_name in expert_name_old2new
+                else state_name: state_dict[state_name]
+                for state_name in state_dict.keys()
+            }
+
+            return renamed_state_dict
+        # TODO: add support for sharding
+        else:
+            return state_dict
+
     def load_state_dict_and_rename(self):
         """
         Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state
@@ -741,11 +798,10 @@ def load_state_dict_and_rename(self):
                     model_state_file_name = self.get_model_state_file_from(file_name)
                     assert model_state_file_name is not None
                     model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name]
-                    renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
-                    self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos)
-                    self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
-                else:
-                    self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos)
+                    state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
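+                # map local-view (per-rank) expert names to global names before collecting shard infos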
+                renamed_state_dict = self.rename_local_view_state_dict(state_dict, file_name)
+                self.get_sharded_tensor_infos(file_name, renamed_state_dict, cur_rank_sharded_tensor_infos)
+                self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
         else:
             for file, state_dict in self.cur_rank_loaded_state_dict.items():
                 # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name,
@@ -897,6 +953,9 @@ def rename(old_name, parameter_to_structured_name):
             return None
 
         for key, value in state_dict.items():
+            # NOTE: Skip parameters that are not initialized, i.e. those not present on the current rank.
+            if value is None or (isinstance(value, paddle.Tensor) and not value._is_initialized()):
+                continue
             if key in parameter_to_structured_name.values():
                 new_name = key
             else:
@@ -909,7 +968,9 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
     def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         name_mapping = {}
         suffix_bucket = {}
-        assert len(optimizer_state_dict) % len(model_state_keys) == 0
+        # TODO: After adapting to sharding, remove the code below.
+        if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+            assert len(optimizer_state_dict) % len(model_state_keys) == 0
         for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
             suffix_bucket[suffix] = []
         for opt_name, opt_value in optimizer_state_dict.items():
@@ -927,10 +988,27 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         for suffix, old_names in suffix_bucket.items():
             if len(old_names) == 0:
                 continue
-            assert len(old_names) == len(model_state_keys)
-            for i in range(len(old_names)):
-                name_mapping[old_names[i]] = model_state_keys[i] + suffix
-
+            # TODO: After adapting to sharding, remove the code below.
+            if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+                assert len(old_names) == len(model_state_keys)
+
+            # NOTE: Handle the case where the number of master_weight elements is not equal to the number of model_state_keys.
+            if suffix != ".master_weight":
+                for i in range(len(old_names)):
+                    name_mapping[old_names[i]] = model_state_keys[i] + suffix
+            else:
+                for i in range(len(old_names)):
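+                    # strip the ".master_weight" suffix (14 chars) to recover the parameter name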
+                    param = old_names[i][:-14]
+                    index = -1
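+                    # find the moment1 entry for the same parameter by stripping its optimizer
+                    # suffix (assumed to be the last 24 chars of the key)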
+                    for idx, opt_name in enumerate(suffix_bucket[".moment1"]):
+                        if param == opt_name[:-24]:
+                            index = idx
+                            break
+                    if index >= 0:
+                        name_mapping[old_names[i]] = model_state_keys[index] + suffix
+                    else:
+                        raise RuntimeError(f"Can't find {param} in optimizer state dict.")
+            # rename state dict
         renamed_state_dict = {}
         for k, v in optimizer_state_dict.items():
             renamed_state_dict[name_mapping[k]] = v