@@ -143,6 +143,7 @@
 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -945,7 +946,8 @@ def _inner_training_loop(
                     ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                ) or (args.recompute and availiable_no_sync
+                ) or (args.use_moe and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
@@ -965,6 +967,14 @@ def _inner_training_loop(
 
                 tr_loss += tr_loss_step
 
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
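
The helper above distinguishes expert (MoE) parameters from shared ones by a parameter-level `no_sync` attribute: only the shared parameters are handed to `fused_allreduce_gradients`, while expert gradients stay local to each data-parallel rank. A minimal standalone sketch of that tagging convention follows; the toy `shared` and `expert` layers are illustrative and not part of the PR.

import paddle
import paddle.nn as nn

# Toy layers standing in for shared and expert weights (illustrative only).
shared = nn.Linear(8, 8)   # gradients get all-reduced across dp ranks
expert = nn.Linear(8, 8)   # expert weights stay local to this rank
for p in expert.parameters():
    p.no_sync = True       # the same attribute the helper checks via getattr

params = list(shared.parameters()) + list(expert.parameters())
nonmoe_list = [p for p in params if not getattr(p, "no_sync", False)]
moe_list = [p for p in params if getattr(p, "no_sync", False)]
print(len(nonmoe_list), len(moe_list))  # 2 2: only nonmoe_list would be all-reduced

Any parameter tagged with `no_sync = True` is simply skipped by the data-parallel all-reduce; keeping those expert weights consistent, if required, is left to the MoE communication path rather than this helper.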
@@ -983,12 +993,12 @@ def _inner_training_loop(
 
                     # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1017,7 @@ def _inner_training_loop(
                                 self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                             if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                                fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                                fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
 
@@ -1028,7 +1037,9 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline MoE does not work under fp16"
+                        scale_before = self.scaler._scale.numpy()
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                         scale_after = self.scaler._scale
@@ -2042,7 +2053,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
 
         model.train()
         inputs = self._prepare_inputs(inputs)
-
+        self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
@@ -2053,7 +2064,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
+        self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()
         return loss.detach()
 
     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2154,19 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         # For ckpt integrity
         paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
 
+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,):
+        # Save MoE optimizer and model state.  # TODO: redundant storage by default
+
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
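
`_save_moe_weights` writes a `saved_signal_<rank>` marker file once its model and optimizer shards are on disk. The consumer of those markers is not shown in this hunk; a hypothetical polling helper such as the one below could wait for a known set of writer ranks before treating the checkpoint directory as complete. The function name, `writer_ranks`, and the timeout are assumptions for illustration, not part of the PR.

import os
import time

def wait_for_saved_signals(output_dir, writer_ranks, timeout=1800.0):
    # Hypothetical consumer of the saved_signal_<rank> marker files.
    expected = [os.path.join(output_dir, f"saved_signal_{rank}") for rank in writer_ranks]
    deadline = time.time() + timeout
    while time.time() < deadline:
        if all(os.path.exists(path) for path in expected):
            return True
        time.sleep(1.0)
    return False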
@@ -2245,6 +2269,8 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
@@ -2476,7 +2502,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # broadcast optimizer state in dp group
         if self.args.local_rank != -1:
             dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if not self.args.use_moe:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        # else:
+        #     opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
 
         if opt_state_dict is not None:
             # Load in optimizer and scheduler states
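
`broadcast_moe_optimizer` is imported at the top of the file but its call is left commented out here, so with `use_moe=True` the loaded optimizer state is not broadcast across the data-parallel group at all, likely because each data-parallel rank now saves and reloads its own MoE optimizer shard (see `_save_moe_weights`). Presumably such a helper would broadcast only the state of dp-synchronized parameters while each rank keeps the state of its own experts. The sketch below is purely illustrative of that idea and is not PaddleNLP's actual implementation; the flat `{param_name: tensor}` state layout, `shared_param_names`, `src_rank`, and `dp_group` are all assumptions.

import paddle
import paddle.distributed as dist

def broadcast_shared_optimizer_state(opt_state_dict, shared_param_names, src_rank, dp_group):
    # Illustrative only: broadcast state tensors belonging to dp-synchronized
    # parameters; expert-parameter state stays rank-local.
    for name, value in opt_state_dict.items():
        if not any(name.startswith(prefix) for prefix in shared_param_names):
            continue  # expert state: each dp rank keeps its own copy
        if isinstance(value, paddle.Tensor):
            dist.broadcast(value, src=src_rank, group=dp_group)  # broadcast updates the tensor in place
    return opt_state_dict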