|
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1753,7 +1754,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1790,7 +1794,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
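
Both hunks swap distributed_gather for distributed_allgather when running on XPU, presumably because the gather collective is not supported there. The two calls differ only in where the shards end up: distributed_gather(tensor, dst=j, ...) delivers the list of shards only to rank j (the usual gather semantics, with other ranks receiving nothing useful), while distributed_allgather(tensor, ...) delivers the full list to every rank in tp_group. Since both call sites already guard the merge with `tensor = action(ret) if is_dst else None`, non-destination ranks simply discard the extra copies, making allgather a drop-in, if more communication-heavy, replacement. A minimal sketch of the shared pattern, assuming the helper name merge_shard (hypothetical; the two collectives and their keyword arguments are taken verbatim from the diff):

    from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
    from paddlenlp.utils.tools import get_env_device

    def merge_shard(tensor, action, dst, is_dst, tp_group):
        # Hypothetical helper illustrating the pattern repeated in both hunks.
        if get_env_device() == "xpu":
            # Every rank receives all shards; ranks other than dst drop them below.
            ret = distributed_allgather(tensor, group=tp_group, offload=False)
        else:
            # Only rank `dst` receives the shards.
            ret = distributed_gather(tensor, dst=dst, group=tp_group, offload=False)
        # `action` recombines the tensor-parallel shards; run it only on the
        # destination rank, mirroring `action(ret) if is_dst else None` above.
        return action(ret) if is_dst else None

The trade-off of this fallback is bandwidth: allgather moves every shard to every rank instead of to a single destination, but it avoids the point-to-point transfers that gather needs, which is what makes the merge work on XPU.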
|