|
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
|
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device

 if is_safetensors_available():
     # from safetensors import safe_open

@@ -1747,7 +1748,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1784,7 +1788,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
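
Both hunks apply the same change: on XPU the tensor-parallel merge goes through an all-gather instead of a rank-targeted gather. A minimal sketch of that pattern, using only the PaddleNLP helpers imported above; the wrapper name `_gather_for_merge` is hypothetical and not part of this patch:

```python
# Sketch of the device-dependent gather from this diff. Assumes the
# PaddleNLP helpers imported above; `_gather_for_merge` is a hypothetical
# name introduced here for illustration.
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.tools import get_env_device


def _gather_for_merge(tensor, dst, tp_group):
    """Collect the tensor-parallel shards of `tensor` for merging.

    `distributed_gather` delivers the shards only to rank `dst`;
    `distributed_allgather` delivers them to every rank in the group,
    which is the route taken on XPU.
    """
    if get_env_device() == "xpu":
        # Every rank in the TP group receives the full shard list.
        return distributed_allgather(tensor, group=tp_group, offload=False)
    # Only rank `dst` receives the shard list; other ranks get nothing back.
    return distributed_gather(tensor, dst=dst, group=tp_group, offload=False)
```

Because the merge action only runs on the destination rank (`tensor = action(ret) if is_dst else None` in both call sites), the extra copies the all-gather produces on the other ranks are discarded immediately; the patch trades some communication and memory on XPU for avoiding the rank-targeted gather on that device.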
|