Skip to content

Commit e471b1e

Browse files
authored
fix xpu gather for unified ckpt (#8710)
1 parent 70564ba commit e471b1e

File tree

1 file changed: 10 additions and 3 deletions

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
get_checkpoint_shard_files,
4242
is_safetensors_available,
4343
)
44-
from paddlenlp.utils.distributed import distributed_gather
44+
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
4545
from paddlenlp.utils.env import (
4646
LORA_WEIGHTS_NAME,
4747
PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
6464
)
6565
from paddlenlp.utils.log import logger
6666
from paddlenlp.utils.nested import nested_copy, nested_copy_place
67+
from paddlenlp.utils.tools import get_env_device
6768

6869
if is_safetensors_available():
6970
# from safetensors import safe_open
@@ -1753,7 +1754,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
17531754
key = filter_keys[i]
17541755
tensor = state_dict[key]
17551756
if key in tp_actions:
1756-
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1757+
if get_env_device() == "xpu":
1758+
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1759+
else:
1760+
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
17571761
action = tp_actions.pop(key)
17581762
tensor = action(ret) if is_dst else None
17591763
else:
@@ -1790,7 +1794,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
17901794
if tensor.numel().item() == 1:
17911795
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
17921796
else:
1793-
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1797+
if get_env_device() == "xpu":
1798+
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1799+
else:
1800+
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
17941801
action = tp_actions[model_key]
17951802
tensor = action(ret) if is_dst else None
17961803
else:

0 commit comments

Comments (0)