 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             )
             self.model.set_state_dict(state_dict)
         else:
-            if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
+            if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

                 weights_file = os.path.join(
                     resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -930,22 +931,17 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 self.timers and self.timers("forward-backward").start()

-                dp_enabled = (
-                    self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
-                )
-                forbidden_no_sync = False
                 # stage2 and stage3 should not no_sync, because there is no DDP wrapper and no_sync API
                 # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
-                if self.args.use_hybrid_parallel:
-                    forbidden_no_sync = True
-
-                availiable_no_sync = dp_enabled and not forbidden_no_sync
-
+                availiable_no_sync = hasattr(model, "no_sync")
                 is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (
+                        ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                        and args._no_sync_in_gradient_accumulation
+                    )
+                    or args.recompute
+                    or args.use_expert_parallel
+                ) and availiable_no_sync
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manually collect gradient on dp group
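
The rewritten gate above reduces to: skip the per-step gradient synchronization whenever the loop is still accumulating, recomputing, or running expert parallelism, provided the wrapped model exposes `no_sync()` at all. A minimal sketch of how such a gate is typically consumed around the training step (the helper name and the `contextlib` fallback below are illustrative, not part of this diff):

# --- illustrative sketch, not part of the diff ---
import contextlib

def no_sync_context(model, is_no_sync):
    """Return model.no_sync() when gradient sync should be skipped this step,
    otherwise a do-nothing context; the caller wraps training_step with it."""
    if is_no_sync and hasattr(model, "no_sync"):
        return model.no_sync()
    return contextlib.nullcontext()
# --- end sketch ---
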
@@ -965,6 +961,14 @@ def _inner_training_loop(

                 tr_loss += tr_loss_step

+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_expert_parallel:
+                        logger.warning("found `no sync` param when `use_expert_parallel=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
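
`fused_allreduce_gradients_no_sync` keys off a `no_sync` attribute on each parameter: tagged (expert) parameters are left out of the data-parallel all-reduce, while everything else still goes through `fused_allreduce_gradients`. A sketch of how a MoE model might tag its expert weights (the helper and its predicate argument below are hypothetical; only the `no_sync` attribute name comes from this diff):

# --- illustrative sketch, not part of the diff ---
def mark_expert_params_no_sync(model, is_expert_param):
    """Tag expert parameters so their gradients stay rank-local.
    `is_expert_param(name, param)` is a caller-supplied predicate, e.g.
    matching 'experts.' in the parameter name."""
    for name, param in model.named_parameters():
        if is_expert_param(name, param):
            param.no_sync = True  # read back via getattr(p, "no_sync", False)
# --- end sketch ---
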
@@ -983,12 +987,12 @@ def _inner_training_loop(

                     # Case 1: Use recompute and dp / sharding stage1,
                     # manually collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@ def _inner_training_loop(
                                 self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                             if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                                fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                                fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()

@@ -1028,6 +1031,8 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_expert_parallel, "pipeline moe not work under fp16"
                         scale_before = paddle.assign(self.scaler._scale)
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
@@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

         model.train()
         inputs = self._prepare_inputs(inputs)
-
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)

@@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
         return loss.detach()

     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

+    def _filter_moe_no_sync_optimizer_params(self):
+        """
+        filter optimizer params which should not sync
+        """
+        state_dict = self.model.state_dict()
+        optimzier_state_dict = self.optimizer.state_dict()
+        filter_optimzier_state_dict = OrderedDict()
+        param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
+        filter_optimzier_state_dict["master_weights"] = OrderedDict()
+        for k, v in state_dict.items():
+            if getattr(v, "no_sync", False):
+                if v.name in param_names_in_master_weights:
+                    filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
+                        v.name
+                    ]
+                for op_k, op_v in optimzier_state_dict.items():
+                    if op_k.startswith(v.name):
+                        filter_optimzier_state_dict[op_k] = op_v
+        return filter_optimzier_state_dict
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
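
`_filter_moe_no_sync_optimizer_params` keeps, for every `no_sync` (expert) parameter, its `master_weights` entry plus every optimizer-state entry whose key starts with the parameter name (moments, beta accumulators, and so on). A standalone sketch of the same filtering over plain dicts; the key layout shown in the docstring is an assumption inferred from the `startswith` check, not taken from the diff:

# --- illustrative sketch, not part of the diff ---
from collections import OrderedDict

def filter_no_sync_entries(opt_state, no_sync_param_names):
    """Keep only the optimizer-state entries that belong to no_sync params.
    Assumed layout: per-param keys such as '<param_name>_moment1_0' plus a
    'master_weights' sub-dict keyed by parameter name."""
    kept = OrderedDict(master_weights=OrderedDict())
    master = opt_state.get("master_weights", {})
    for name in no_sync_param_names:
        if name in master:
            kept["master_weights"][name] = master[name]
        for key, value in opt_state.items():
            if isinstance(key, str) and key.startswith(name):
                kept[key] = value
    return kept
# --- end sketch ---
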
@@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None):
         optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

         if self.args.use_hybrid_parallel:
-            if self.dp_group.rank <= 0:
+            if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
                 os.makedirs(output_dir, exist_ok=True)
                 logger.info("Saving optimizer files.")
                 if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@ def _save_checkpoint(self, model, metrics=None):
                         safe_serialization=True,
                     )
                 else:
-                    self._save_ckpt_func(
-                        self.optimizer.state_dict(),
-                        os.path.join(output_dir, optimizer_name),
-                    )
+                    if self.dp_group.rank > 0:  # this should only work for MoE saving
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(),
+                            os.path.join(output_dir, optimizer_name),
+                        )
+                    else:
+                        self._save_ckpt_func(
+                            self.optimizer.state_dict(),
+                            os.path.join(output_dir, optimizer_name),
+                        )

-        if self.args.should_save:
+        if self.args.should_save or self.args.use_expert_parallel:
             if not self.args.use_hybrid_parallel:
                 logger.info("Saving optimizer files.")
                 if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@ def _save_checkpoint(self, model, metrics=None):
                         safe_serialization=True,
                     )
                 else:
-                    self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+                    if self.dp_group.rank > 0:
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
+                        )
+                    else:
+                        self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

             # FIXME: maybe only save one copy
             paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
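
With the two branches above, checkpointing under expert parallelism splits the optimizer payload by data-parallel rank: rank 0 still writes the full state dict, while every other rank writes only the slice covering its local experts. A toy summary of that branching (the function and its return strings are illustrative only):

# --- illustrative sketch, not part of the diff ---
def optimizer_payload_for_rank(dp_rank, use_expert_parallel):
    """What a rank persists for the optimizer under the branching above."""
    if dp_rank <= 0:
        return "full optimizer state_dict"
    if use_expert_parallel:
        return "filtered state_dict (no_sync / expert params only)"
    return "nothing"
# --- end sketch ---
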
@@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

         if not use_unified_checkpoint:
-            if self.args.data_parallel_rank == 0:
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
                 optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 path = os.path.join(checkpoint, optimizer_name)
                 if os.path.isfile(path):
@@ -2476,7 +2510,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # broadcast optimizer state in dp group
             if self.args.local_rank != -1:
                 dist.barrier()
-            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+            if self.args.use_expert_parallel:
+                opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
+            else:
+                if not self.args.should_load_sharding_stage1_model:
+                    opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

         if opt_state_dict is not None:
             # Load in optimizer and scheduler states
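
`broadcast_moe_optimizer` is imported at the top of this diff, but its body is not shown here. Judging from the rest of the change, the split it has to make is roughly: broadcast the shared (non-expert) optimizer slice from data-parallel rank 0 as before, while leaving the expert (`no_sync`) slice untouched on each rank. A rough, framework-agnostic sketch of that idea (the predicate and the broadcast callable are assumptions, not the library's API):

# --- illustrative sketch, not part of the diff ---
def broadcast_shared_optimizer_state(opt_state, is_expert_key, broadcast_from_dp0):
    """Sync only the shared slice of the optimizer state across the dp group;
    expert entries stay rank-local. `broadcast_from_dp0` is a caller-supplied
    collective (e.g. a wrapper over the framework's dp-group broadcast)."""
    expert_state = {k: v for k, v in opt_state.items() if is_expert_key(k)}
    shared_state = {k: v for k, v in opt_state.items() if not is_expert_key(k)}
    shared_state = broadcast_from_dp0(shared_state)  # identical on every dp rank
    shared_state.update(expert_state)  # re-attach the local expert slice
    return shared_state
# --- end sketch ---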