Commit 31e47a5

[Unified Checkpoint] update async_save_info in develop (#9173)
1 parent f80a051 commit 31e47a5

File tree

1 file changed: +15 additions, -14 deletions

paddlenlp/trainer/trainer.py

Lines changed: 15 additions & 14 deletions
@@ -2297,16 +2297,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
-            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
-            # backup and remove unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = []
-
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
-
-            # recover unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(output_dir, exist_ok=True)
@@ -2584,10 +2575,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
 
-        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
         if (
             strtobool(os.getenv("FLAG_LLM_PDC", "False"))
-            and local_rank == 0
+            and paddle.distributed.get_rank() == 0
             and self.args.unified_checkpoint
             and "async_save" in self.args.unified_checkpoint_config
         ):
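
The hunk above swaps a node-local rank check for a global one. A minimal sketch of the difference, assuming a multi-node Paddle job (the helper names are illustrative, not part of the commit):

import os

import paddle.distributed as dist


def is_node_rank_zero() -> bool:
    # Old guard: PADDLE_RANK_IN_NODE is the rank *within one node*, so the
    # first process of every node passes this check.
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0)) == 0


def is_global_rank_zero() -> bool:
    # New guard: the global rank is 0 on exactly one process in the whole
    # job, so only a single process writes async_save_info.json.
    return dist.get_rank() == 0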
@@ -2598,9 +2588,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
-                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
-                    json.dump(save_info, f)
+            if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):  # afs cannot overwrite
+                os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
+            with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                json.dump(save_info, f)
 
         if self.args.should_save:
             if self.tokenizer is not None:
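
Per the comment in the added lines, the target file system (AFS) cannot overwrite an existing file, so the info file is now deleted and rewritten on every save instead of being written only if absent. A standalone sketch of the same remove-then-write pattern; write_async_save_info and its arguments are illustrative, not the trainer's API:

import json
import os


def write_async_save_info(logging_dir: str, save_info: dict) -> None:
    # Remove-then-write: on file systems that refuse in-place overwrites,
    # delete any stale copy first, then create the file fresh so each save
    # records the latest settings.
    path = os.path.join(logging_dir, "async_save_info.json")
    if os.path.exists(path):
        os.remove(path)
    with open(path, "w") as f:
        json.dump(save_info, f)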
@@ -2609,7 +2600,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
         if self.args.unified_checkpoint:
+            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
+            # backup and remove unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = []
+
             self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
+
+            # recover unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
+
             return
 
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
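
The last hunk re-adds the backup/restore of unified_checkpoint_config that the first hunk removed from save_model, so the toggle now wraps only the unified-checkpoint save inside _save. The same pattern can be expressed as a context manager; this is only an illustrative sketch with a made-up helper name, while the trainer itself keeps the explicit if blocks shown above:

from contextlib import contextmanager


@contextmanager
def plain_checkpoint_outside_training(args, is_in_train: bool):
    # Back up the config, clear it when not in the training stage, and
    # restore it afterwards, mirroring the added lines in _save().
    backup = args.unified_checkpoint_config
    if not is_in_train:
        args.unified_checkpoint_config = []
    try:
        yield
    finally:
        if not is_in_train:
            args.unified_checkpoint_config = backup

A try/finally as used here also restores the config if the save raises, which the explicit if blocks in the diff do not.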

0 commit comments