Skip to content

Commit 2d0f836

Browse files
committed
fix
1 parent a1b9580 commit 2d0f836

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -547,30 +547,16 @@ def _save_checkpoint(self, model, metrics=None):
547547
else:
548548
optim_state_dict = self.optimizer.state_dict()
549549
optim_state_dict.pop("LR_Scheduler", None)
550-
550+
opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]
551551
for p_name, p in model.state_dict().items():
552552
if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
553553
var_name = p.name
554-
if (
555-
var_name + "_moment1_0" in optim_state_dict
556-
and not optim_state_dict[var_name + "_moment1_0"].is_dist()
557-
):
558-
optim_state_dict.pop(var_name + "_moment1_0")
559-
if (
560-
var_name + "_moment2_0" in optim_state_dict
561-
and not optim_state_dict[var_name + "_moment2_0"].is_dist()
562-
):
563-
optim_state_dict.pop(var_name + "_moment2_0")
564-
if (
565-
var_name + "_beta1_pow_acc_0" in optim_state_dict
566-
and not optim_state_dict[var_name + "_beta1_pow_acc_0"].is_dist()
567-
):
568-
optim_state_dict.pop(var_name + "_beta1_pow_acc_0")
569-
if (
570-
var_name + "_beta2_pow_acc_0" in optim_state_dict
571-
and not optim_state_dict[var_name + "_beta2_pow_acc_0"].is_dist()
572-
):
573-
optim_state_dict.pop(var_name + "_beta2_pow_acc_0")
554+
for key in opt_state_keys:
555+
if (
556+
var_name + key in optim_state_dict
557+
and not optim_state_dict[var_name + key].is_dist()
558+
):
559+
optim_state_dict.pop(var_name + key)
574560

575561
state_dict = {
576562
MODEL_NAME: model.state_dict(),

0 commit comments

Comments
 (0)