@@ -2297,16 +2297,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
-            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
-            # backup and remove unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = []
-
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
-
-            # recover unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(output_dir, exist_ok=True)
@@ -2584,10 +2575,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
 
-        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
         if (
             strtobool(os.getenv("FLAG_LLM_PDC", "False"))
-            and local_rank == 0
+            and paddle.distributed.get_rank() == 0
             and self.args.unified_checkpoint
             and "async_save" in self.args.unified_checkpoint_config
         ):
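A note on the guard change above: PADDLE_RANK_IN_NODE appears to be the rank within a single node, so the old check could pass on the first GPU of every node, whereas paddle.distributed.get_rank() is the global rank and passes on exactly one process per job. A minimal sketch of the distinction, assuming a multi-node launch (illustrative only, not part of the patch):

import os
import paddle.distributed as dist

# On, say, 2 nodes x 8 GPUs: PADDLE_RANK_IN_NODE runs 0..7 on each node,
# while dist.get_rank() runs 0..15 across the whole job.
node_local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))  # 0 on the first GPU of every node
global_rank = dist.get_rank()                               # 0 on a single process in the job
should_write_metadata = global_rank == 0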
@@ -2598,9 +2588,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
-                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
-                    json.dump(save_info, f)
+            if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):  # afs cannot overwrite
+                os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
+            with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                json.dump(save_info, f)
 
         if self.args.should_save:
             if self.tokenizer is not None:
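The hunk above replaces a write-once guard with delete-then-write, since (per the inline comment) the target filesystem (AFS) cannot overwrite an existing file in place. A small standalone sketch of the same pattern; the helper name and signature are illustrative, not taken from the patch:

import json
import os

def dump_json_replacing(path, payload):
    # Some filesystems (e.g. AFS) refuse to overwrite an existing file,
    # so remove any stale copy before writing the fresh one.
    if os.path.exists(path):
        os.remove(path)
    with open(path, "w") as f:
        json.dump(payload, f)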
@@ -2609,7 +2600,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
         if self.args.unified_checkpoint:
+            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
+            # backup and remove unified_checkpoint_config for non-train stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = []
+
             self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
+
+            # recover unified_checkpoint_config for non-train stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
+
             return
 
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
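Overall, the patch moves the unified_checkpoint_config backup/restore from save_model into _save, so it now wraps only the save_unified_checkpoint call. A hedged sketch of that save/restore pattern using try/finally, so the config is restored even if saving raises; the function name is illustrative and the handler call mirrors the patch rather than defining the library API:

def save_unified_with_temp_config(trainer, output_dir):
    # Outside of training (e.g. a final export), drop the unified checkpoint
    # options for this one save, then put them back afterwards.
    backup = trainer.args.unified_checkpoint_config
    if not trainer.is_in_train:
        trainer.args.unified_checkpoint_config = []
    try:
        trainer.unified_checkpoint_handler.save_unified_checkpoint(
            trainer.model, trainer.optimizer, output_dir
        )
    finally:
        trainer.args.unified_checkpoint_config = backup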