Skip to content

Commit 440c093

Browse files
authored
[fix] Broadcast optimizer state using broadcast_dp without shard-reshard. (#8522)
1 parent e71540b commit 440c093

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
582 582   weights_index_file,
583 583   ]
584 584   ):
585     - raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
    585 + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint} -- {weights_file}")
586 586
587 587   logger.info(f"Loading model from {resume_from_checkpoint} .")
588 588
@@ -2237,7 +2237,7 @@ def _save_checkpoint(self, model, metrics=None):
2237 2237   safe_serialization=True,
2238 2238   )
2239 2239   else:
2240      - if self.dp_group.rank > 0:
     2240 + if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
2241 2241   self._save_ckpt_func(
2242 2242   self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
2243 2243   )
@@ -2525,7 +2525,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
2525 2525   if self.args.local_rank != -1:
2526 2526   dist.barrier()
2527 2527   if self.args.use_expert_parallel:
2528      - opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
     2528 + opt_state_dict = broadcast_moe_optimizer(
     2529 + opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
     2530 + )
2529 2531   else:
2530 2532   if not self.args.should_load_sharding_stage1_model:
2531 2533   opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

paddlenlp/trainer/utils/helper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def broadcast_dp_optimizer(state_dict):
228 228   return state_dict
229 229
230 230
231     - def broadcast_moe_optimizer(state_dict):
    231 + def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
232 232
233 233   try:
234 234   hcg = fleet.get_hybrid_communicate_group()
@@ -270,7 +270,10 @@ def _broadcast_moe_optimizer_state(state_dict):
270 270   base_state_dict.update(buf[2])
271 271   return base_state_dict
272 272
273     - base_state_dict = _broadcast_moe_optimizer_state(state_dict)
    273 + if broadcast_dp:
    274 + base_state_dict = broadcast_dp_optimizer(state_dict)
    275 + else:
    276 + base_state_dict = _broadcast_moe_optimizer_state(state_dict)
274 277   if data_parallel_rank > 0:
275 278   master_weight = state_dict.pop("master_weights", {})
276 279   base_state_dict.update(state_dict)

0 commit comments

Comments
 (0)