@@ -547,30 +547,16 @@ def _save_checkpoint(self, model, metrics=None):
547547 else :
548548 optim_state_dict = self .optimizer .state_dict ()
549549 optim_state_dict .pop ("LR_Scheduler" , None )
550-
550+ opt_state_keys = [ "_moment1_0" , "_moment2_0" , "_beta1_pow_acc_0" , "_beta2_pow_acc_0" ]
551551 for p_name , p in model .state_dict ().items ():
552552 if paddle .distributed .get_rank () not in p .process_mesh .process_ids :
553553 var_name = p .name
554- if (
555- var_name + "_moment1_0" in optim_state_dict
556- and not optim_state_dict [var_name + "_moment1_0" ].is_dist ()
557- ):
558- optim_state_dict .pop (var_name + "_moment1_0" )
559- if (
560- var_name + "_moment2_0" in optim_state_dict
561- and not optim_state_dict [var_name + "_moment2_0" ].is_dist ()
562- ):
563- optim_state_dict .pop (var_name + "_moment2_0" )
564- if (
565- var_name + "_beta1_pow_acc_0" in optim_state_dict
566- and not optim_state_dict [var_name + "_beta1_pow_acc_0" ].is_dist ()
567- ):
568- optim_state_dict .pop (var_name + "_beta1_pow_acc_0" )
569- if (
570- var_name + "_beta2_pow_acc_0" in optim_state_dict
571- and not optim_state_dict [var_name + "_beta2_pow_acc_0" ].is_dist ()
572- ):
573- optim_state_dict .pop (var_name + "_beta2_pow_acc_0" )
554+ for key in opt_state_keys :
555+ if (
556+ var_name + key in optim_state_dict
557+ and not optim_state_dict [var_name + key ].is_dist ()
558+ ):
559+ optim_state_dict .pop (var_name + key )
574560
575561 state_dict = {
576562 MODEL_NAME : model .state_dict (),
0 commit comments