Commit 2097916

fix load best
1 parent 75c7636 commit 2097916

2 files changed: +28 -0 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 10 additions & 0 deletions
@@ -89,6 +89,7 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
+    _load_state_dict_into_model,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -149,6 +150,7 @@
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.helper import (  # nested_truncate,
+    broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
     broadcast_moe_optimizer,
     distributed_concat,
@@ -1161,6 +1163,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 self.state.best_model_checkpoint,
                 safe_serialization=True,
             )
+            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
+                if self.args.dataset_rank > 0:
+                    _load_state_dict_into_model(self.model, state_dict, "")
         else:
             weight_name = PADDLE_WEIGHTS_NAME
             best_model_path = os.path.join(
@@ -1203,6 +1209,10 @@ def _load_best_model_from_peft_checkpoint(self):
                 self.state.best_model_checkpoint,
                 safe_serialization=True,
             )
+            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict())
+                if self.args.dataset_rank > 0:
+                    _load_state_dict_into_model(self.model, state_dict, "")
             return

         convert_tp = False
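
For context, the trainer.py hunks above wire the new helper into best-model loading: once dataset rank 0 has reloaded the best unified checkpoint, the other sharding/data-parallel ranks receive the weights by broadcast instead of keeping whatever was in memory. The sketch below restates that pattern as a standalone function; the function name and the bare model/args parameters are illustrative assumptions, while the two imported helpers are exactly the ones added in the diff.

# Minimal sketch (not part of the commit) of the broadcast-then-load pattern.
from paddlenlp.trainer.utils.helper import broadcast_dataset_rank0_model
from paddlenlp.transformers.model_utils import _load_state_dict_into_model


def sync_best_model_across_dataset_ranks(model, args):
    """Give every sharding/data-parallel rank the best-model weights."""
    if args.sharding_parallel_degree > 1 or args.data_parallel_degree > 1:
        # Dataset rank 0 just reloaded the best checkpoint; broadcast its
        # tensors so all ranks in the group hold an identical state dict.
        state_dict = broadcast_dataset_rank0_model(model.state_dict())
        if args.dataset_rank > 0:
            # Ranks that did not reload from disk overwrite their parameters
            # with the broadcast copy (the empty-string prefix mirrors the
            # call in the diff).
            _load_state_dict_into_model(model, state_dict, "")
    return model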

paddlenlp/trainer/utils/helper.py

Lines changed: 18 additions & 0 deletions
@@ -309,3 +309,21 @@ def _broadcast_moe_optimizer_state(state_dict):
         state_dict = base_state_dict
         del base_state_dict
     return state_dict
+
+
+def broadcast_dataset_rank0_model(state_dict):
+    if paddle.distributed.get_world_size() <= 1:
+        return state_dict
+
+    logger.info("Start broadcast model in sharding group or data parallel group.")
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
+
+    if sharding_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_sharding_parallel_group_src_rank(), group=sharding_group)
+    if dp_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_data_parallel_group_src_rank(), group=dp_group)
+    return state_dict
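
The helper itself is a no-op in single-process runs; otherwise it broadcasts every tensor in the state dict from the source rank of the sharding group and then from the source rank of the data-parallel group. Below is a hedged usage sketch under Paddle's hybrid-parallel fleet setup; the degrees, the stand-in Linear model, and the launch command are assumptions chosen for illustration, not configuration taken from this commit.

# Usage sketch for broadcast_dataset_rank0_model. Run under a distributed
# launcher, e.g.: python -m paddle.distributed.launch --devices "0,1" demo.py
# The hybrid degrees (pure data parallelism over 2 ranks) are an assumption.
import paddle
from paddle.distributed import fleet

from paddlenlp.trainer.utils.helper import broadcast_dataset_rank0_model

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1, "sharding_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# Each rank starts with its own randomly initialised parameters.
model = paddle.nn.Linear(8, 8)

# After the call, every rank's tensors match data-parallel rank 0's.
synced = broadcast_dataset_rank0_model(model.state_dict())
model.set_state_dict(synced)

Because dist.broadcast writes into the tensors in place and Paddle's state_dict returns the live parameters, the final set_state_dict call is largely redundant; it is kept only to make the synchronisation explicit.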
