Skip to content

Commit 72db7c2

Browse files
committed
fix xpu gather for unified ckpt (PaddlePaddle#8710)
1 parent d9ea4c8 commit 72db7c2

File tree: 1 file changed (+10 additions, −3 deletions)

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device

 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1747,7 +1748,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1784,7 +1788,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions[model_key]
                 tensor = action(ret) if is_dst else None
             else:

0 commit comments — Comments (0)