Skip to content

Commit b218d65

Browse files
taozhiwei, pre-commit-ci[bot], SkafteNicki, and Borda
authored and committed
let _get_default_process_group_backend_for_device support more hardware platforms (#21057)
* Support more hardware platforms and no longer hard-code CUDA when calling _get_default_process_group_backend_for_device. * Apply suggestions from code review. --------- Signed-off-by: taozhiwei <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit 119a640)
1 parent 1fcbc6c commit b218d65

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
319319

320320

321321
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
322-
return "nccl" if device.type == "cuda" else "gloo"
322+
"""Return corresponding distributed backend for a given device."""
323+
device_backend_map = torch.distributed.Backend.default_device_backend_map
324+
if device.type in device_backend_map:
325+
return device_backend_map[device.type]
326+
return "gloo"
323327

324328

325329
class _DatasetSamplerWrapper(Dataset):

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from lightning.fabric.utilities.distributed import (
1818
_destroy_dist_connection,
1919
_gather_all_tensors,
20+
_get_default_process_group_backend_for_device,
2021
_InfiniteBarrier,
2122
_init_dist_connection,
2223
_is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
243244
atexit_mock.register.assert_not_called()
244245

245246

247+
def test_get_default_process_group_backend_for_device():
    """Test that each device type maps to its correct default process group backend."""
    # Register a fake accelerator device type ("pcu") backed by a custom
    # process-group backend ("pccl") so a non-builtin mapping can be checked.
    torch.utils.rename_privateuse1_backend("pcu")

    def _noop_backend(store, group_rank, group_size, timeout):
        pass

    torch.distributed.Backend.register_backend("pccl", _noop_backend, devices=["pcu"])

    # Each device type should resolve to its registered default backend.
    expected = {"cpu": "gloo", "cuda:0": "nccl", "pcu:0": "pccl"}
    for device_str, backend in expected.items():
        assert _get_default_process_group_backend_for_device(torch.device(device_str)) == backend
266+
267+
246268
@RunIf(min_torch="2.4")
247269
def test_is_dtensor(monkeypatch):
248270
from torch.distributed._tensor import DTensor

0 commit comments

Comments (0)