
Commit dbfd2b2

Robert Shaw (robertgshaw2-redhat), Nick Hill (njhill), and Tyler Michael Smith (tlrmchlsmth) authored and committed
[DP] Internal Load Balancing Per Node [one-pod-per-node] (vllm-project#21238)
Signed-off-by: Robert Shaw <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
1 parent bb21f76 commit dbfd2b2

File tree

12 files changed: +486, -45 lines


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions

@@ -166,6 +166,7 @@ steps:
     - tests/v1/test_async_llm_dp.py
     - tests/v1/test_external_lb_dp.py
     - tests/v1/test_internal_lb_dp.py
+    - tests/v1/test_hybrid_lb_dp.py
     - tests/v1/engine/test_engine_core_client.py
   commands:
     # test with tp=2 and external_dp=2
@@ -178,6 +179,7 @@ steps:
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
     - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
+    - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
     - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
     - pytest -v -s distributed/test_utils.py
     - pytest -v -s compile/test_basic_correctness.py
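
For local reproduction outside Buildkite, the new CI step can be driven from Python as well; a minimal sketch, assuming a 4-GPU machine with pytest installed and the vLLM repository's tests/ directory as the working directory:

# Minimal sketch: run the new hybrid-LB test the same way the CI step does.
# Assumes a 4-GPU machine and the vLLM repo's tests/ directory as the cwd.
import os
import subprocess

env = dict(os.environ, TP_SIZE="1", DP_SIZE="4")
subprocess.run(["pytest", "-v", "-s", "v1/test_hybrid_lb_dp.py"],
               env=env,
               check=True)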

tests/v1/engine/test_engine_core_client.py

Lines changed: 2 additions & 2 deletions

@@ -565,8 +565,8 @@ def create_mock_executor(vllm_config):

     from vllm.v1.engine.utils import EngineZmqAddresses

-    def mock_startup_handshake(self, handshake_socket, on_head_node,
-                               parallel_config):
+    def mock_startup_handshake(self, handshake_socket, local_client,
+                               headless, parallel_config):
         return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
                                   outputs=["tcp://127.0.0.1:5556"],
                                   coordinator_input=None,

tests/v1/test_hybrid_lb_dp.py

Lines changed: 352 additions & 0 deletions (new file)

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
from vllm.platforms import Platform

MODEL_NAME = "ibm-research/PowerMoE-3b"

# Number of data parallel ranks for hybrid LB testing (4 total)
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))

# Number of nodes (2 nodes, each with 2 DP ranks)
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES  # 2 ranks per node


class HybridLBServerManager:
    """Manages hybrid data parallel vLLM server instances where each node
    runs a single logical API server that balances requests only to the
    DP engines running on that same node."""

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 dp_size_local: int = DP_SIZE_LOCAL,
                 tp_size: int = TP_SIZE):
        self.model_name = model_name
        self.dp_size = dp_size
        self.dp_size_local = dp_size_local
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []
        self.num_nodes = dp_size // dp_size_local

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for hybrid LB mode."""
        for node_id in range(self.num_nodes):
            # Create server args for this specific node
            server_args = self.base_server_args.copy()

            # Calculate start rank for this node
            start_rank = node_id * self.dp_size_local

            # Add hybrid LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-size-local",
                str(self.dp_size_local),
                "--data-parallel-start-rank",
                str(start_rank),
                "--data-parallel-hybrid-lb",  # Enable hybrid LB mode
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + node_id),  # Different port for each node
                "--api-server-count",
                str(self.api_server_count),
                "--data-parallel-address",
                "127.0.0.1",
                "--data-parallel-rpc-port",
                "13345",
            ])

            # Use a thread to start each server to allow parallel initialization
            def start_server(node: int, sargs: list[str]):
                try:
                    # Calculate GPU devices for this node
                    gpus_per_node = self.dp_size_local * self.tp_size
                    gpu_start = node * gpus_per_node
                    gpu_end = gpu_start + gpus_per_node

                    # Start the server
                    server = RemoteOpenAIServer(
                        self.model_name,
                        sargs,
                        auto_port=False,
                        env_dict={
                            "CUDA_VISIBLE_DEVICES":
                            ",".join(
                                str(Platform.device_id_to_physical_device_id(
                                    i)) for i in range(gpu_start, gpu_end))
                        })
                    server.__enter__()
                    print(f"Hybrid LB node {node} started successfully with "
                          f"{self.dp_size_local} local DP ranks and "
                          f"{self.api_server_count} API servers")
                    self.servers.append((server, sargs))
                except Exception as e:
                    print(f"Failed to start hybrid LB node {node}: {e}")
                    raise

            thread = threading.Thread(target=start_server,
                                      args=(node_id, server_args))
            thread.start()

            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        # Give servers additional time to fully initialize and coordinate
        time.sleep(3)

        if len(self.servers) != self.num_nodes:
            raise Exception("Servers failed to start")

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances."""
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]


@pytest.fixture(scope="module", params=[1])  # Only 1 API server for now
def servers(request, default_server_args):
    api_server_count = request.param
    with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
                               default_server_args, DP_SIZE_LOCAL,
                               TP_SIZE) as server_list:
        yield server_list


@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    # Create a client for each node (each node has its own API endpoint)
    async with AsyncExitStack() as stack:
        yield [
            await stack.enter_async_context(server.get_async_client())
            for server, _ in servers
        ]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
                                    servers: list[tuple[RemoteOpenAIServer,
                                                        list[str]]],
                                    model_name: str) -> None:

    async def make_request(client: openai.AsyncOpenAI):
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # Test single request to each node
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(
            f"Hybrid LB node {i} handled single completion request successfully"
        )

    await asyncio.sleep(0.5)

    # Send requests to all nodes - each should balance within its local DP ranks
    num_requests_per_node = 25  # Total 50 requests across 2 nodes
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_node)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(completion is not None for completion in results)

    await asyncio.sleep(0.5)

    # Second burst of requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_node)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(completion is not None for completion in results)

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(
        f"Successfully completed hybrid LB test with {len(clients)} nodes "
        f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
    )

    # Check request balancing within each node
    for i, (server, _) in enumerate(servers):
        print(f"Checking request balancing for node {i}")
        check_request_balancing(server, DP_SIZE_LOCAL)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_hybrid_lb_completion_streaming(clients: list[
        openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer,
                                                 list[str]]],
                                              model_name: str) -> None:
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    # Test single request to each node
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(
            f"Hybrid LB node {i} handled single streaming request successfully"
        )

    await asyncio.sleep(0.5)

    # Send streaming requests to all nodes
    num_requests_per_node = 25  # Total 50 requests across 2 nodes
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_node)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    await asyncio.sleep(0.5)

    # Second burst of streaming requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_node)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(f"Successfully completed hybrid LB streaming test with "
          f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
          f"API server count: {api_server_count})")

    # Check request balancing within each node
    for i, (server, _) in enumerate(servers):
        print(f"Checking streaming request balancing for node {i}")
        check_request_balancing(server, DP_SIZE_LOCAL)
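
The per-node bookkeeping in HybridLBServerManager reduces to simple arithmetic over (dp_size, dp_size_local, tp_size). A standalone sketch of that mapping follows; the helper name node_layout and the plain GPU indices are illustrative only, since the test itself maps logical to physical devices via Platform.device_id_to_physical_device_id:

# Standalone sketch of the per-node layout arithmetic used above.
# node_layout is an illustrative helper, not part of the diff.
def node_layout(node_id: int,
                dp_size_local: int = 2,
                tp_size: int = 1,
                base_port: int = 8000) -> dict:
    gpus_per_node = dp_size_local * tp_size
    gpu_start = node_id * gpus_per_node
    return {
        "data_parallel_start_rank": node_id * dp_size_local,
        "port": base_port + node_id,
        "cuda_visible_devices": ",".join(
            str(i) for i in range(gpu_start, gpu_start + gpus_per_node)),
    }


if __name__ == "__main__":
    for node in range(2):  # NUM_NODES = 2 in the test
        print(node, node_layout(node))
    # node 0 -> start rank 0, port 8000, GPUs "0,1"
    # node 1 -> start rank 2, port 8001, GPUs "2,3"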

vllm/config.py

Lines changed: 10 additions & 2 deletions

@@ -1922,8 +1922,16 @@ class ParallelConfig:
     """Backend to use for data parallel, either "mp" or "ray"."""
     data_parallel_external_lb: bool = False
     """Whether to use "external" DP LB mode. Applies only to online serving
-    and when data_parallel_size > 0. Set implicitly when
-    data_parallel_rank is provided explicitly to vllm serve."""
+    and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
+    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
+    is provided explicitly to vllm serve."""
+    data_parallel_hybrid_lb: bool = False
+    """Whether to use "hybrid" DP LB mode. Applies only to online serving
+    and when data_parallel_size > 0. Enables running an AsyncLLM
+    and API server on a "per-node" basis where vLLM load balances
+    between local data parallel ranks, but an external LB balances
+    between vLLM nodes/replicas. Set explicitly in conjunction with
+    --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     enable_eplb: bool = False
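
Combined with the flags exercised in the test above, the new data_parallel_hybrid_lb mode corresponds to one vllm serve invocation per node, each owning only its local DP ranks while an external load balancer spreads traffic across nodes. A hedged sketch of a two-node launch driven from Python; the model, ports, address, and RPC port mirror the test's values and are illustrative, not prescriptive:

# Hedged sketch of a two-node hybrid-LB launch mirroring the flags the test
# passes to each per-node server. Values (model, ports, address) come from the
# test, not from vLLM defaults.
import subprocess

DP_SIZE, DP_SIZE_LOCAL, TP_SIZE = 4, 2, 1


def node_command(node_id: int) -> list[str]:
    return [
        "vllm", "serve", "ibm-research/PowerMoE-3b",
        "--data-parallel-size", str(DP_SIZE),
        "--data-parallel-size-local", str(DP_SIZE_LOCAL),
        "--data-parallel-start-rank", str(node_id * DP_SIZE_LOCAL),
        "--data-parallel-hybrid-lb",            # vLLM balances local ranks only
        "--tensor-parallel-size", str(TP_SIZE),
        "--port", str(8000 + node_id),          # external LB targets these ports
        "--data-parallel-address", "127.0.0.1",
        "--data-parallel-rpc-port", "13345",
    ]


# In a real deployment each command runs on its own node; on a single 4-GPU
# host the two "nodes" can run side by side for experimentation.
if __name__ == "__main__":
    procs = [subprocess.Popen(node_command(n)) for n in range(2)]
    for p in procs:
        p.wait()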
