
Commit 8669149

[Trainer] Support MoE
1 parent efd29c0 commit 8669149

File tree

3 files changed (+148, -27 lines)


paddlenlp/trainer/trainer.py

Lines changed: 88 additions & 23 deletions
@@ -48,6 +48,7 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
+from paddle.utils import map_structure

 try:
     from paddle.distributed.fleet.utils.hybrid_parallel_util import (
@@ -143,6 +144,7 @@
 from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -940,12 +942,17 @@ def _inner_training_loop(
             forbidden_no_sync = True

         availiable_no_sync = dp_enabled and not forbidden_no_sync
+        has_no_sync = hasattr(model, "no_sync")

         is_no_sync = (
-            ((step_control + 1) % args.gradient_accumulation_steps != 0)
-            and availiable_no_sync
-            and args._no_sync_in_gradient_accumulation
-        ) or (args.recompute and availiable_no_sync)
+            (
+                ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                and availiable_no_sync
+                and args._no_sync_in_gradient_accumulation
+            )
+            or (args.recompute and availiable_no_sync)
+            or args.use_moe
+        )
         # sharding
         # stage1. the same as ddp
         # stage2. manualy collect gradient on dp group
@@ -956,14 +963,25 @@ def _inner_training_loop(
         if dp_master_grad:
             is_no_sync = True

-        if is_no_sync:
+        if is_no_sync and has_no_sync:
             # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
             with model.no_sync():
                 tr_loss_step = self.training_step(model, inputs)
         else:
             tr_loss_step = self.training_step(model, inputs)

-        tr_loss += tr_loss_step
+        def fused_allreduce_gradients_no_sync(param_list, hcg):
+            param_list = list(param_list)
+            nonmoe_list = [p for p in param_list if not getattr(p, "no_sync", False)]
+            moe_list = [p for p in param_list if getattr(p, "no_sync", False)]
+            if moe_list and not self.args.use_moe:
+                logger.warning("found `no_sync` param when `use_moe=False`")
+            fused_allreduce_gradients(nonmoe_list, hcg)
+
+        if tr_loss_step is not None:
+            if tr_loss is None:
+                tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)
+            map_structure(lambda x, y: x.add_(y), tr_loss, tr_loss_step)

         if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
             # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -983,12 +1001,13 @@ def _inner_training_loop(

             # Case 1: Use recompute and dp / sharding stage1,
             # manualy collect gradient for dp.
-            if args.recompute and availiable_no_sync:
-                fused_allreduce_gradients(list(model.parameters()), None)
+            # Case 1.1: pure dp + moe should manually collect gradient here.
+            if (args.recompute or args.use_moe) and availiable_no_sync:
+                fused_allreduce_gradients_no_sync(list(model.parameters()), None)

             # Case 2: hack dp with master_grad
             if dp_master_grad and not (args.recompute and availiable_no_sync):
-                fused_allreduce_gradients(list(model.parameters()), None)
+                fused_allreduce_gradients_no_sync(list(model.parameters()), None)

             # Pipeline parallel mode, handle gradient reduce here to overlap
             pipeline_parallel_config = (
@@ -1007,7 +1026,9 @@ def _inner_training_loop(
                     self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

                 if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                    fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
+                    fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
+                else:
+                    assert not self.args.use_moe, "moe should not use `enable_dp_comm_overlap`"

             self.timers and self.timers("all-reduce").stop()
             self.timers and self.timers("optimizer-step").start()
@@ -1132,7 +1153,7 @@ def _inner_training_loop(
                 "on multiple nodes, you should activate `--save_on_each_node`."
             )

-        self._total_loss_scalar += tr_loss.item()
+        self._total_loss_scalar += tr_loss.pop("loss").item() if isinstance(tr_loss, dict) else tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step

         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
@@ -1250,12 +1271,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs: Dict[str, float] = {}

             # all_gather + mean() to get average loss over all processes
-            tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
+            tr_loss_scalar = map_structure(lambda x: self._get_item_from_loss(self._nested_gather(x).mean()), tr_loss)

             # reset tr_loss to zero
-            tr_loss.subtract_(tr_loss)
-
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
+            map_structure(lambda x: x.zero_(), tr_loss)
+
+            if isinstance(tr_loss_scalar, dict):
+                for k, v in tr_loss_scalar.items():
+                    logs[k] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            elif isinstance(tr_loss_scalar, (list, tuple)):
+                for i, v in enumerate(tr_loss_scalar):
+                    logs[f"loss_{i}"] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            else:
+                logs["loss"] = round(
+                    tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged),
+                    8,
+                )
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
@@ -1290,7 +1321,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                 )
             )

-            self._total_loss_scalar += tr_loss_scalar
+            self._total_loss_scalar += (
+                tr_loss_scalar.pop("loss") if isinstance(tr_loss_scalar, dict) else tr_loss_scalar
+            )
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()
@@ -2047,14 +2080,19 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             loss = self.compute_loss(model, inputs)

         if self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
-            loss = loss / self.args.gradient_accumulation_steps
+            loss = map_structure(lambda x: x / self.args.gradient_accumulation_steps, loss)
+
+        if isinstance(loss, dict):
+            total_loss = loss["loss"]
+        else:
+            total_loss = loss

         if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
+            self.scaler.scale(total_loss).backward()
         else:
-            loss.backward()
+            total_loss.backward()

-        return loss.detach()
+        return map_structure(lambda v: v.detach(), loss)

     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         """
@@ -2113,6 +2151,18 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle

         return loss.detach()

+    def _save_moe_weights(
+        self,
+        output_dir,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        if not self.args.ignore_save_lr_and_optim:
+            self._save_ckpt_func(
+                self.optimizer.state_dict(),
+                os.path.join(output_dir, _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)),
+            )
+
     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2126,7 +2176,12 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
         if ShardingOption.FULL_SHARD in self.args.sharding:
             self.model_wrapped.get_all_parameters(convert2cpu=True)

-        if self.args.should_save_model_state:
+        if not self.is_in_train and self.args.use_moe:
+            should_save_model_state = self.args.should_save_moe_model_state
+        else:
+            should_save_model_state = self.args.should_save_model_state
+
+        if should_save_model_state:
             unified_checkpoint_config_backup = self.args.unified_checkpoint_config
             # backup and remove unified_checkpoint_config for not trine stage
             if not self.is_in_train:
@@ -2245,6 +2300,10 @@ def _save_checkpoint(self, model, metrics=None):
             os.makedirs(output_dir, exist_ok=True)
             paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            logger.info("Saving moe weights for data_parallel_rank > 0")
+            self._save_moe_weights(output_dir)
+
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
@@ -2452,7 +2511,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

         if not use_unified_checkpoint:
-            if self.args.data_parallel_rank == 0:
+            if self.args.use_moe or self.args.data_parallel_rank == 0:
                 optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 path = os.path.join(checkpoint, optimizer_name)
                 if os.path.isfile(path):
@@ -2476,7 +2535,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # broadcast optimizer state in dp group
         if self.args.local_rank != -1:
             dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if not self.args.use_moe:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        else:
+            state_dict = self.model.state_dict()
+            opt_state_dict = broadcast_moe_optimizer(state_dict, opt_state_dict)

         if opt_state_dict is not None:
             # Load in optimizer and scheduler states
@@ -2939,6 +3002,8 @@ def prediction_step(
         if has_labels:
             with self.autocast_smart_context_manager():
                 loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                if isinstance(loss, dict):
+                    loss = loss.pop("loss")
                 loss = loss.mean().detach()

         if isinstance(outputs, dict):
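Note on the loss handling above: after this change, `training_step` may return either a single tensor or a nested structure such as a dict with a "loss" entry plus auxiliary MoE losses, and the training loop accumulates it with `paddle.utils.map_structure`. The snippet below is a minimal, standalone sketch of that accumulation pattern; the `accumulate` helper and the specific loss keys are illustrative only, not part of the commit.

import paddle
from paddle.utils import map_structure

def accumulate(tr_loss, tr_loss_step):
    # tr_loss_step may be a plain tensor or a dict of tensors; the same code
    # path handles both, mirroring the loop in _inner_training_loop above.
    if tr_loss is None:
        tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)
    map_structure(lambda acc, step: acc.add_(step), tr_loss, tr_loss_step)
    return tr_loss

step_a = {"loss": paddle.to_tensor(1.5), "aux_loss": paddle.to_tensor(0.1)}
step_b = {"loss": paddle.to_tensor(1.2), "aux_loss": paddle.to_tensor(0.2)}
total = accumulate(None, step_a)
total = accumulate(total, step_b)
print({k: round(float(v), 4) for k, v in total.items()})
# {'loss': 2.7, 'aux_loss': 0.3} -- only the "loss" entry is used for backward()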
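The `fused_allreduce_gradients_no_sync` closure above encodes the core MoE rule: parameters tagged `no_sync` (expert parameters) are excluded from the data-parallel gradient all-reduce, while shared parameters are reduced as before. A minimal illustration of that split follows; the `_P` parameter stub and `split_for_allreduce` are hypothetical helpers for this sketch only.

class _P:
    # Tiny stand-in for a Paddle parameter, illustrative only.
    def __init__(self, name, no_sync=False):
        self.name, self.no_sync = name, no_sync

def split_for_allreduce(param_list):
    # Mirror of the closure above: expert ("no_sync") params are skipped during
    # the data-parallel all-reduce; only the shared params are reduced.
    shared = [p for p in param_list if not getattr(p, "no_sync", False)]
    experts = [p for p in param_list if getattr(p, "no_sync", False)]
    return shared, experts

params = [_P("shared.w_0"), _P("expert_3.w_0", no_sync=True)]
shared, experts = split_for_allreduce(params)
print([p.name for p in shared], [p.name for p in experts])
# ['shared.w_0'] ['expert_3.w_0']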

paddlenlp/trainer/training_args.py

Lines changed: 27 additions & 4 deletions
@@ -1650,11 +1650,11 @@ def optimizer_name_suffix(self):
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
             if self.use_moe:
-                name.append(f"moe{self.data_parallel_rank:0>2d}")
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
             if self.use_moe:
-                return f"moe{self.data_parallel_rank:0>2d}"
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

     @property
@@ -1666,12 +1666,12 @@ def weight_name_suffix(self):
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_moe:
-                name.append(f"moe{self.data_parallel_rank:0>2d}")
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)

         else:
             if self.use_moe:
-                return f"moe{self.data_parallel_rank:0>2d}"
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

     def sharded_name_suffix(self, shard_id=None, pp_id=None):
@@ -1787,6 +1787,29 @@ def should_save_model_state(self):
         else:
             return self.process_index == 0

+    @property
+    def should_save_moe_model_state(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save moe models and checkpoints.
+
+        For model state:
+            work for data parallel, tensor parallel, sharding
+        For optimizer state:
+            work for data parallel, tensor parallel
+            not work for sharding
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if self.should_save_sharding_stage1_model:
+                return True
+            elif self.enable_auto_parallel:
+                return True
+            elif self.use_hybrid_parallel:
+                return self.sharding_parallel_rank == 0
+            else:
+                return self.process_index == 0
+
     @property
     def _no_sync_in_gradient_accumulation(self):
         """

paddlenlp/trainer/utils/helper.py

Lines changed: 33 additions & 0 deletions
@@ -226,3 +226,36 @@ def broadcast_dp_optimizer(state_dict):
         state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

     return state_dict
+
+
+def broadcast_moe_optimizer(state_dict, opt_state_dict):
+    no_sync_vname = []
+    for k, v in state_dict.items():
+        if getattr(v, "no_sync", False):
+            no_sync_vname.append(v.name)
+    new_opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+    # 1. when updating opt_state_dict, skip broadcasting parameters whose `no_sync` attribute is True.
+    # 2. if the keys of opt_state_dict and new_opt_state_dict are exactly the same, there is no need to update.
+    # 3. if they are different, the update should be based on `no_sync_vname`.
+    if len(opt_state_dict.keys()) != len(new_opt_state_dict.keys()):
+        for op_k, op_v in new_opt_state_dict.items():
+            if op_k == "master_weights":
+                for k, v in new_opt_state_dict["master_weights"].items():
+                    no_sync = False
+                    for no_sync_v in no_sync_vname:
+                        if k.startswith(no_sync_v):
+                            no_sync = True
+                            break
+                    if not no_sync:
+                        opt_state_dict["master_weights"][k] = v
+            elif op_k == "LR_Scheduler":
+                pass
+            else:
+                no_sync = False
+                for no_sync_v in no_sync_vname:
+                    if op_k.startswith(no_sync_v):
+                        no_sync = True
+                        break
+                if not no_sync:
+                    opt_state_dict[op_k] = op_v
+    return opt_state_dict
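A note on the merge rule implemented by `broadcast_moe_optimizer` above: the dp-rank-0 optimizer state is broadcast as usual, but any entry whose key starts with the name of a `no_sync` (expert) parameter keeps its local value, so expert-parallel optimizer state is never overwritten by another rank. The snippet below is a simplified, flat-dict sketch of that filtering rule; it ignores the nested `master_weights` handling and the key-count check in the real function.

def merge_broadcast_state(local_state, broadcast_state, no_sync_vnames):
    # Keep expert-local ("no_sync") entries; take broadcast values for the rest.
    merged = dict(local_state)
    for key, value in broadcast_state.items():
        if key == "LR_Scheduler":
            continue  # scheduler state is left untouched, matching the diff
        if any(key.startswith(vname) for vname in no_sync_vnames):
            continue  # expert-parallel state stays local
        merged[key] = value
    return merged

local = {"expert_0.w_0_moment1_0": "local", "shared.w_0_moment1_0": "stale"}
remote = {"expert_0.w_0_moment1_0": "rank0", "shared.w_0_moment1_0": "rank0"}
print(merge_broadcast_state(local, remote, no_sync_vnames=["expert_0.w_0"]))
# {'expert_0.w_0_moment1_0': 'local', 'shared.w_0_moment1_0': 'rank0'}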
