 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
+from paddle.utils import map_structure

 try:
     from paddle.distributed.fleet.utils.hybrid_parallel_util import (
@@ -143,6 +144,7 @@
 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -940,12 +942,17 @@ def _inner_training_loop(
                     forbidden_no_sync = True

                 availiable_no_sync = dp_enabled and not forbidden_no_sync
+                has_no_sync = hasattr(model, "no_sync")
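+                # A bare (non-distributed) model has no `no_sync` context manager, so only use it when present.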

                 is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (
+                        ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                        and availiable_no_sync
+                        and args._no_sync_in_gradient_accumulation
+                    )
+                    or (args.recompute and availiable_no_sync)
+                    or args.use_moe
+                )
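+                # NOTE: with `use_moe`, the automatic DDP all-reduce is always skipped so that expert
+                # (`no_sync`) gradients stay local; non-expert gradients are collected manually below.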
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
@@ -956,14 +963,25 @@ def _inner_training_loop(
                 if dp_master_grad:
                     is_no_sync = True

-                if is_no_sync:
+                if is_no_sync and has_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
                         tr_loss_step = self.training_step(model, inputs)
                 else:
                     tr_loss_step = self.training_step(model, inputs)

-                tr_loss += tr_loss_step
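+                # Helper: all-reduce only the parameters that are not flagged with `no_sync`
+                # (MoE expert parameters keep their gradients local to each dp rank).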
+                def fused_allreduce_gradients_no_sync(param_list, hcg):
+                    param_list = list(param_list)
+                    nonmoe_list = [p for p in param_list if not getattr(p, "no_sync", False)]
+                    moe_list = [p for p in param_list if getattr(p, "no_sync", False)]
+                    if moe_list and not self.args.use_moe:
+                        logger.warning("found `no_sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
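+                # `tr_loss_step` may be a nested structure (e.g. a dict of losses), so accumulate it element-wise.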
+                if tr_loss_step is not None:
+                    if tr_loss is None:
+                        tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)
+                    map_structure(lambda x, y: x.add_(y), tr_loss, tr_loss_step)

                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -983,12 +1001,13 @@ def _inner_training_loop(

                     # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    # Case 1.1: pure dp + moe should manually collect gradient here.
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Case 2: hack dp with master_grad
                     if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,7 +1026,9 @@ def _inner_training_loop(
                                 self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                             if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                                fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
+                                fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
+                            else:
+                                assert not self.args.use_moe, "moe should not use `enable_dp_comm_overlap`"
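+                                # presumably because the overlapped dp reduction cannot skip the MoE `no_sync` params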

                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
@@ -1132,7 +1153,7 @@ def _inner_training_loop(
11321153 "on multiple nodes, you should activate `--save_on_each_node`."
11331154 )
11341155
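+        # With a dict-valued loss (MoE), only the main "loss" entry feeds the reported train_loss.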
-        self._total_loss_scalar += tr_loss.item()
+        self._total_loss_scalar += tr_loss.pop("loss").item() if isinstance(tr_loss, dict) else tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step

         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
@@ -1250,12 +1271,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs: Dict[str, float] = {}

             # all_gather + mean() to get average loss over all processes
-            tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
+            tr_loss_scalar = map_structure(lambda x: self._get_item_from_loss(self._nested_gather(x).mean()), tr_loss)

             # reset tr_loss to zero
-            tr_loss.subtract_(tr_loss)
-
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
+            map_structure(lambda x: x.zero_(), tr_loss)
+
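+            # A structured loss (dict / list) is logged entry by entry; a plain scalar keeps the usual "loss" key.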
+            if isinstance(tr_loss_scalar, dict):
+                for k, v in tr_loss_scalar.items():
+                    logs[k] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            elif isinstance(tr_loss_scalar, (list, tuple)):
+                for i, v in enumerate(tr_loss_scalar):
+                    logs[f"loss_{i}"] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            else:
+                logs["loss"] = round(
+                    tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged),
+                    8,
+                )
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

@@ -1290,7 +1321,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                     )
                 )

-            self._total_loss_scalar += tr_loss_scalar
+            self._total_loss_scalar += (
+                tr_loss_scalar.pop("loss") if isinstance(tr_loss_scalar, dict) else tr_loss_scalar
+            )
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()

@@ -2047,14 +2080,19 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             loss = self.compute_loss(model, inputs)

         if self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
-            loss = loss / self.args.gradient_accumulation_steps
+            loss = map_structure(lambda x: x / self.args.gradient_accumulation_steps, loss)
+
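+        # MoE models may return a dict of losses; only the main "loss" entry is backpropagated,
+        # while the full structure is returned (detached) for logging.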
+        if isinstance(loss, dict):
+            total_loss = loss["loss"]
+        else:
+            total_loss = loss

         if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
+            self.scaler.scale(total_loss).backward()
         else:
-            loss.backward()
+            total_loss.backward()

-        return loss.detach()
+        return map_structure(lambda v: v.detach(), loss)

     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         """
@@ -2113,6 +2151,18 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle

         return loss.detach()

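+    # Saves this rank's model weights and optimizer state; used for MoE checkpoints, where
+    # expert parameters (and their optimizer state) differ across data-parallel ranks.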
+    def _save_moe_weights(
+        self,
+        output_dir,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        if not self.args.ignore_save_lr_and_optim:
+            self._save_ckpt_func(
+                self.optimizer.state_dict(),
+                os.path.join(output_dir, _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)),
+            )
+
     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2126,7 +2176,12 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         if ShardingOption.FULL_SHARD in self.args.sharding:
             self.model_wrapped.get_all_parameters(convert2cpu=True)

-        if self.args.should_save_model_state:
+        if not self.is_in_train and self.args.use_moe:
+            should_save_model_state = self.args.should_save_moe_model_state
+        else:
+            should_save_model_state = self.args.should_save_model_state
+
+        if should_save_model_state:
             unified_checkpoint_config_backup = self.args.unified_checkpoint_config
             # backup and remove unified_checkpoint_config for not trine stage
             if not self.is_in_train:
@@ -2245,6 +2300,10 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

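+        # MoE expert weights and optimizer state are rank-specific, so data-parallel ranks other than 0 save them too.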
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            logger.info("Saving moe weights for data_parallel_rank > 0")
+            self._save_moe_weights(output_dir)
+
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
@@ -2452,7 +2511,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

         if not use_unified_checkpoint:
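+            # With MoE, every data-parallel rank loads its own optimizer state instead of only dp rank 0.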
-            if self.args.data_parallel_rank == 0:
+            if self.args.use_moe or self.args.data_parallel_rank == 0:
                 optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 path = os.path.join(checkpoint, optimizer_name)
                 if os.path.isfile(path):
@@ -2476,7 +2535,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # broadcast optimizer state in dp group
         if self.args.local_rank != -1:
             dist.barrier()
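+        # With MoE, a plain dp broadcast would overwrite the rank-local expert optimizer state, so the
+        # MoE-aware helper, which also takes the model state_dict, is used instead.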
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if not self.args.use_moe:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        else:
+            state_dict = self.model.state_dict()
+            opt_state_dict = broadcast_moe_optimizer(state_dict, opt_state_dict)

         if opt_state_dict is not None:
             # Load in optimizer and scheduler states
@@ -2939,6 +3002,8 @@ def prediction_step(
             if has_labels:
                 with self.autocast_smart_context_manager():
                     loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
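+                # A dict-valued loss (MoE) is reduced to its main "loss" entry for evaluation.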
+                if isinstance(loss, dict):
+                    loss = loss.pop("loss")
                 loss = loss.mean().detach()

                 if isinstance(outputs, dict):