Commit da78cae

[core][distributed] add zmq fallback for broadcasting large objects (#6183)
1 parent 2416b26 commit da78cae

File tree

6 files changed: +274 additions, -80 deletions

requirements-common.txt

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
 outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+pyzmq
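
pyzmq becomes a dependency because the new broadcast path sends objects that are too large for the fixed-size shared-memory ring buffer over ZeroMQ sockets instead. A minimal sketch of that fallback pattern, assuming a PUB/SUB topology and pickle serialization (the address and helper names are illustrative, not vLLM's actual wiring):

import pickle

import zmq


def make_writer(address: str = "tcp://127.0.0.1:5555") -> zmq.Socket:
    # Writer side: bind a PUB socket that every reader subscribes to.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUB)
    sock.bind(address)
    return sock


def make_reader(address: str = "tcp://127.0.0.1:5555") -> zmq.Socket:
    # Reader side: connect a SUB socket and subscribe to all messages.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.SUB)
    sock.setsockopt(zmq.SUBSCRIBE, b"")
    sock.connect(address)
    return sock


def broadcast_large_object(sock: zmq.Socket, obj) -> None:
    # ZeroMQ imposes no fixed message size, so large pickled objects are fine.
    sock.send(pickle.dumps(obj))


def receive_large_object(sock: zmq.Socket):
    return pickle.loads(sock.recv())

In practice the writer's address would have to be exchanged through the existing torch.distributed process group before any ZeroMQ traffic starts, and readers must be connected before the first publish to avoid PUB/SUB's slow-joiner message loss.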

tests/distributed/test_same_node.py

Lines changed: 3 additions & 2 deletions
@@ -2,10 +2,11 @@

 import torch

-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as

 torch.distributed.init_process_group(backend="gloo")
-test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))

 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
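
The removed is_in_the_same_node(group) returned a single boolean for the whole group; judging from the call sites in this diff, in_the_same_node_as(group, source_rank) instead returns one flag per rank, which callers reduce with all(...) when they need the old all-or-nothing answer. A small usage sketch under that assumption:

import torch.distributed as dist

from vllm.distributed.parallel_state import in_the_same_node_as

# Assumes torch.distributed has already been initialized, as in the test above.
# One flag per rank: whether that rank shares a physical node with source_rank.
flags = in_the_same_node_as(dist.group.WORLD, source_rank=0)

# Equivalent of the old is_in_the_same_node(): the whole group is on one node.
single_node = all(flags)

# The per-rank view also supports finer-grained decisions, e.g. counting local peers.
num_local_ranks = sum(flags)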

tests/distributed/test_shm_broadcast.py

Lines changed: 3 additions & 14 deletions
@@ -6,8 +6,7 @@
 import numpy as np
 import torch.distributed as dist

-from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.utils import update_environment_variables


@@ -56,8 +55,8 @@ def wrapped_fn(env):
 @worker_fn_wrapper
 def worker_fn():
     writer_rank = 2
-    broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+    broadcaster = MessageQueue.create_from_process_group(
+        dist.group.WORLD, 40 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
         seed = random.randint(0, 1000)
         dist.broadcast_object_list([seed], writer_rank)
@@ -87,13 +86,3 @@ def worker_fn():

 def test_shm_broadcast():
     distributed_run(worker_fn, 4)
-
-
-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]
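
MessageQueue replaces the ShmRingBuffer/ShmRingBufferIO pair from the old test, and the test buffer shrinks from 1 MiB to 40 KiB because oversized payloads no longer need to fit in shared memory. A rough sketch of the dispatch idea named in the commit title, small serialized objects through a bounded shared-memory ring and anything larger through ZeroMQ (the class, method names, and threshold below are illustrative stand-ins, not MessageQueue's real internals):

import pickle


class SizeDispatchingBroadcaster:
    # Illustrative only: `ring` stands in for a fixed-capacity shared-memory
    # ring buffer and `zmq_sock` for a ZeroMQ PUB socket as sketched earlier.

    def __init__(self, ring, zmq_sock, max_chunk_bytes: int = 40 * 1024):
        self.ring = ring
        self.zmq_sock = zmq_sock
        self.max_chunk_bytes = max_chunk_bytes

    def enqueue(self, obj) -> None:
        payload = pickle.dumps(obj)
        if len(payload) < self.max_chunk_bytes:
            # Fast path: the payload fits into one shared-memory chunk.
            self.ring.write(payload)
        else:
            # Fallback path: ZeroMQ handles arbitrarily large messages.
            self.zmq_sock.send(payload)

A receiving side would need a matching branch to tell ring-buffer chunks from ZeroMQ frames.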

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink

@@ -64,7 +64,7 @@ def __init__(self,
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")

-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
             # No need to initialize custom allreduce for multi-node case.
             logger.warning(
                 "Custom allreduce is disabled because this process group"