 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _load_state_dict_into_model,
+    faster_set_state_dict,
     get_parameter_dtype,
     load_state_dict,
     unwrap_model,
 from paddlenlp.utils.nested import nested_copy, nested_copy_place

 if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open

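Note that `fast_safe_open` is imported under the name `safe_open`, so every existing call site keeps working unchanged. A minimal sketch of the reader interface this relies on, assuming `fast_safe_open` mirrors the `safetensors.safe_open` context manager (`keys()`/`get_tensor()`); the shard filename is hypothetical:

```python
from paddlenlp.utils.safetensors import fast_safe_open as safe_open

# Hypothetical shard name; framework="np" yields numpy arrays as before.
with safe_open("model-00001-of-00002.safetensors", framework="np") as f:
    for key in f.keys():            # tensor names stored in the shard
        tensor = f.get_tensor(key)  # read lazily, one tensor at a time
```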
 FP32_MASTER = "fp32_master_0"
 optimizer_scalar_name = [
@@ -196,7 +198,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
     Returns:
         None
     """
-
     if paddle.distributed.get_world_size() <= 1:
         load_single_card_checkpoint(args, model, resume_from_checkpoint)
         return
@@ -222,7 +223,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
         pretrained_model_name_or_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
     )
-
     loaded_keys = sharded_metadata["all_checkpoint_keys"]

     model_state_dict = get_expected_state_dict(model)
@@ -266,7 +266,9 @@ def _remove_unused_keys(
         else:
             tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
         # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )

         if not pre_tensor_parallel_split:
             # Since we load all keys but we only need one of pipeline stages
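Passing `device="expected"` lets `load_state_dict` materialize each tensor directly on its destination device rather than staging everything on CPU and copying afterwards. A rough illustration of what that implies; the helper below is an assumption about the semantics, not the actual implementation:

```python
import paddle

def place_loaded_tensor(np_array, device="cpu"):
    # Illustrative sketch only. With device="expected" the loader puts the
    # tensor straight on the place the model expects (e.g. the current GPU),
    # skipping the CPU staging copy plus the later host-to-device transfer.
    if device == "expected":
        return paddle.to_tensor(np_array)  # lands on the current default place
    return paddle.to_tensor(np_array, place=paddle.CPUPlace())
```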
@@ -279,11 +281,12 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )

-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

         # force memory release
         del state_dict
-        gc.collect()
+        # gc.collect()

     if len(error_msgs) > 0:
         error_msg = "\n\t".join(error_msgs)
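Swapping `_load_state_dict_into_model` for `faster_set_state_dict` (with `strict_dtype=False` to tolerate dtype mismatches) also makes the explicit `gc.collect()` unnecessary, since tensors are released as they are consumed. A simplified sketch of the idea, assuming the helper assigns values into existing parameter storage; this hypothetical version is not the real implementation:

```python
import paddle

def naive_faster_set_state_dict(model, state_dict, strict_dtype=True):
    # Sketch: write loaded values into existing parameter storage in place,
    # avoiding the slower per-parameter copy of _load_state_dict_into_model.
    error_msgs = []
    for name, param in model.state_dict().items():
        if name not in state_dict:
            error_msgs.append(f"missing key: {name}")
            continue
        value = state_dict.pop(name)  # drop the dict reference as we go
        if value.dtype != param.dtype:
            if strict_dtype:
                error_msgs.append(f"dtype mismatch for key: {name}")
                continue
            value = paddle.cast(value, param.dtype)
        param.set_value(value)  # in-place update, no parameter re-creation
    return error_msgs
```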
@@ -337,6 +340,7 @@ def unified_checkpoint_into_shards(
         tp_actions = model_to_save.get_tensor_parallel_convert_actions(
             model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
         )
+        logger.info("Unified model tensor parallel weights in shards")
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

     # build index json file
@@ -490,6 +494,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
         resolved_archive_file = [resolved_archive_file]
+
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
@@ -537,10 +542,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

                 # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
             else:
                 # for pipeline model, we don't need to use tp_actions
-                state_dict = load_state_dict(shard_file, None, expected_keys)
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

             returned_state_dict.update(state_dict)
             # force memory release
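The optimizer and master-weight shard loaders take the same `device="expected"` fast path as the model weights, so optimizer states also skip the intermediate CPU staging copy (see the sketch after the earlier model-weight hunk).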
@@ -553,7 +558,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         state_dict_master_weight = load_resolved_archive_file(
             resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
         )
-
     # rename optimizer param
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
@@ -562,13 +566,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name

     if has_master_weights:
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     returned_optim_state_dict = nested_copy_place(
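Using `pop` instead of plain indexing hands each tensor's reference over to the destination dict, so the source dict shrinks as the loop runs and peak memory stays near a single copy. A small self-contained illustration of the pattern (hypothetical helper):

```python
def rename_and_release(loaded, rename):
    # With loaded[key], every tensor stays alive in both dicts until the
    # loop ends; with pop, the old reference disappears as soon as the
    # tensor is re-keyed, letting memory be reclaimed incrementally.
    renamed = {}
    for key in list(loaded.keys()):  # snapshot keys; the dict shrinks mid-loop
        renamed[rename(key)] = loaded.pop(key)
    return renamed
```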
@@ -640,6 +644,7 @@ def unified_optimizer_into_shards(
         tp_actions = model.get_tensor_parallel_convert_actions(
             model.config, model_keys, is_split=False, ignore_error=True
         )
+        logger.info("Unified optimizer tensor parallel in shards")
         optim_state_dict = merge_tensor_parallel_for_optimizer(
             optim_state_dict,
             tp_actions,
@@ -648,6 +653,7 @@ def unified_optimizer_into_shards(
         )
         paddle.device.cuda.empty_cache()

         if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
             master_weights = merge_tensor_parallel_for_optimizer(
                 master_weights,
                 tp_actions,
@@ -703,7 +709,6 @@ def unified_optimizer_into_shards(
 def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
     index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
     # Find index json file and distribute this file in global group.
     if distributed_isfile(index_filename):
         distributed_file(index_filename)
@@ -1605,7 +1610,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()

-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )

     if tp_group.nranks > 1:
         dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1721,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1741,7 +1746,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[key] = tensor
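Staging gathered tensors in `paddle.CUDAPinnedPlace()` instead of `paddle.CPUPlace()` keeps them in page-locked host memory, which the GPU can DMA directly without an extra pageable-to-pinned copy; the optimizer merge below gets the same treatment. A minimal sketch of the pattern, assuming a CUDA build of Paddle (the second `_copy_to` argument mirrors the flag used in the merge code):

```python
import paddle

x = paddle.randn([1024, 1024])  # lives on the current (GPU) place
# Stage in page-locked host memory; False matches the blocking flag above.
pinned = x._copy_to(paddle.CUDAPinnedPlace(), False)
# Reloading `pinned` onto the GPU later avoids the pageable staging copy
# that a plain CPUPlace tensor would require.
```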
@@ -1754,8 +1759,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


 def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function of the unified checkpoint (UC) save path
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1774,14 +1778,14 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
                     tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                        tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor