
[DP] Internal Load Balancing Per Node [one-pod-per-node] #21238


Merged
97 commits merged on Jul 24, 2025

Commits (97)
14f13ed
added debug logging
Jul 19, 2025
b90d331
updated
Jul 20, 2025
aefeeed
updated
Jul 20, 2025
59a9583
updated
Jul 20, 2025
48cf09b
updated
Jul 20, 2025
2fd0587
updated
Jul 20, 2025
14cf3c4
updated
Jul 20, 2025
4f5d3ea
updated
Jul 20, 2025
14db660
updated
Jul 20, 2025
2aa4975
updated
Jul 20, 2025
b142571
cleanup
Jul 20, 2025
e1843b7
updated
Jul 20, 2025
d2d54e9
updated
Jul 20, 2025
4438796
fix lb issues
Jul 20, 2025
2a68433
updated
Jul 20, 2025
1ced153
updatedd
Jul 20, 2025
b9c0f65
nits
Jul 20, 2025
dbc51d6
nits
Jul 20, 2025
471fa4a
updated
Jul 20, 2025
6569fac
stash
Jul 20, 2025
1e5303a
stash
Jul 20, 2025
a69edca
convert to use only one prometheus stat logger per async llm
Jul 20, 2025
de91a3c
convert to use only one prometheus stat logger per async llm
Jul 20, 2025
e08e1e9
cleanup prometheus logging
Jul 20, 2025
d39cf93
updated
Jul 20, 2025
9a2e26d
updated
Jul 20, 2025
3956d8c
updated
Jul 20, 2025
cad9670
updated
Jul 20, 2025
fd0650f
updated
Jul 20, 2025
896b0a2
updated
Jul 20, 2025
54e405b
updated
Jul 20, 2025
02ecfa8
updated
Jul 20, 2025
1358836
updated
Jul 20, 2025
4eae5cb
updated
Jul 20, 2025
5e6114d
Merge pull request #19 from robertgshaw2-redhat/fix-prometheus-logging
robertgshaw2-redhat Jul 20, 2025
c08fb6d
updated
Jul 20, 2025
d9291f9
cleanup
Jul 20, 2025
876c864
updated
Jul 20, 2025
f477b50
updated
Jul 20, 2025
5ea4fa2
updated
Jul 20, 2025
e9e180d
cleanup
Jul 20, 2025
3f4ae35
updated
Jul 20, 2025
840d381
updated
Jul 20, 2025
1b488f8
Merge branch 'main' into one-pod-per-node-lb
Jul 20, 2025
e540aa4
revert logger changes
Jul 20, 2025
72d2c87
nit comments
Jul 20, 2025
6206a06
nit comments
Jul 20, 2025
c229904
refactor ux
Jul 20, 2025
d9ea345
refactor ux
Jul 20, 2025
d4ab18f
updated
Jul 20, 2025
2cf8ff6
updated
Jul 20, 2025
99583c2
updated
Jul 20, 2025
ad34f4a
updated
Jul 20, 2025
fc79d23
updated
Jul 20, 2025
6491d59
cleanup
Jul 20, 2025
7127d83
updated
Jul 20, 2025
cae7cb0
updated
Jul 20, 2025
093b938
updated
Jul 20, 2025
0018dd0
debug
Jul 20, 2025
a46bc0a
updated
Jul 20, 2025
85cd2da
seems to be working again, but LB is wrong
Jul 20, 2025
9f7d321
stash
Jul 20, 2025
9160888
updated
Jul 20, 2025
be03d84
stash
Jul 20, 2025
32a35f5
stash
Jul 20, 2025
fe68027
updated
Jul 20, 2025
2d32c28
stash
Jul 20, 2025
d327a6b
cleanup
Jul 20, 2025
ec86e79
updated
Jul 20, 2025
3c206b1
updated
Jul 20, 2025
a588928
updated
Jul 20, 2025
e81c277
updated
Jul 20, 2025
1dcd900
updated
Jul 21, 2025
5f0663b
cleanup
Jul 21, 2025
6feb456
update ux
Jul 21, 2025
f53166a
update ux
Jul 21, 2025
e80c015
updated
Jul 21, 2025
1b481d3
updated
Jul 21, 2025
40397e3
finished validating
Jul 21, 2025
58e4227
Update vllm/engine/arg_utils.py
robertgshaw2-redhat Jul 21, 2025
7a793ad
fix data_parallel_hybrid_lb arg default value
njhill Jul 21, 2025
60ae223
Merge remote-tracking branch 'origin/main' into one-pod-per-node-lb
njhill Jul 21, 2025
36ed9f3
fix coordinator for hybrid LB mode
njhill Jul 22, 2025
82f9292
infer hybrid lb mode on secondary modes
njhill Jul 22, 2025
f27a85d
add cross-node dp arg validation
njhill Jul 22, 2025
cecf38a
fix assert
njhill Jul 22, 2025
75bd8ea
fix handshake
njhill Jul 22, 2025
aca3ce6
fix cross-node headless arg validation
njhill Jul 22, 2025
f63cc19
fix handshake mock test
njhill Jul 23, 2025
1bd5f2f
Merge remote-tracking branch 'refs/remotes/origin/main' into one-pod-…
njhill Jul 23, 2025
8601a22
fix bad merge
njhill Jul 23, 2025
d95aedd
[Tests] Add tests for headless internal DP LB
njhill Jul 23, 2025
6328c80
CI tests for hybrid DPLB mode
njhill Jul 23, 2025
1c300fc
fix internal_dp_lb tests
njhill Jul 23, 2025
fb0cf7e
rename test
njhill Jul 23, 2025
5fb6809
Merge remote-tracking branch 'origin/main' into one-pod-per-node-lb
njhill Jul 23, 2025
35f3782
relax hybrid dp asserts
tlrmchlsmth Jul 23, 2025
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -166,6 +166,7 @@ steps:
   - tests/v1/test_async_llm_dp.py
   - tests/v1/test_external_lb_dp.py
   - tests/v1/test_internal_lb_dp.py
+  - tests/v1/test_hybrid_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
   commands:
   # test with tp=2 and external_dp=2
@@ -178,6 +179,7 @@ steps:
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
+  - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
4 changes: 2 additions & 2 deletions tests/v1/engine/test_engine_core_client.py
@@ -565,8 +565,8 @@ def create_mock_executor(vllm_config):

     from vllm.v1.engine.utils import EngineZmqAddresses

-    def mock_startup_handshake(self, handshake_socket, on_head_node,
-                               parallel_config):
+    def mock_startup_handshake(self, handshake_socket, local_client,
+                               headless, parallel_config):
         return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
                                   outputs=["tcp://127.0.0.1:5556"],
                                   coordinator_input=None,
352 changes: 352 additions & 0 deletions tests/v1/test_hybrid_lb_dp.py
@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
from vllm.platforms import Platform

MODEL_NAME = "ibm-research/PowerMoE-3b"

# Number of data parallel ranks for hybrid LB testing (4 total)
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))

# Number of nodes (2 nodes, each with 2 DP ranks)
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES  # 2 ranks per node


class HybridLBServerManager:
    """Manages hybrid data parallel vLLM server instances where each node
    runs a single logical API server that balances requests only to the
    DP engines running on that same node."""

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 dp_size_local: int = DP_SIZE_LOCAL,
                 tp_size: int = TP_SIZE):
        self.model_name = model_name
        self.dp_size = dp_size
        self.dp_size_local = dp_size_local
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []
        self.num_nodes = dp_size // dp_size_local

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for hybrid LB mode."""
        for node_id in range(self.num_nodes):
            # Create server args for this specific node
            server_args = self.base_server_args.copy()

            # Calculate start rank for this node
            start_rank = node_id * self.dp_size_local

            # Add hybrid LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-size-local",
                str(self.dp_size_local),
                "--data-parallel-start-rank",
                str(start_rank),
                "--data-parallel-hybrid-lb",  # Enable hybrid LB mode
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + node_id),  # Different port for each node
                "--api-server-count",
                str(self.api_server_count),
                "--data-parallel-address",
                "127.0.0.1",
                "--data-parallel-rpc-port",
                "13345",
            ])

            # Use a thread to start each server to allow parallel initialization
            def start_server(node: int, sargs: list[str]):
                try:
                    # Calculate GPU devices for this node
                    gpus_per_node = self.dp_size_local * self.tp_size
                    gpu_start = node * gpus_per_node
                    gpu_end = gpu_start + gpus_per_node

                    # Start the server
                    server = RemoteOpenAIServer(
                        self.model_name,
                        sargs,
                        auto_port=False,
                        env_dict={
                            "CUDA_VISIBLE_DEVICES":
                            ",".join(
                                str(Platform.device_id_to_physical_device_id(
                                    i)) for i in range(gpu_start, gpu_end))
                        })
                    server.__enter__()
                    print(f"Hybrid LB node {node} started successfully with "
                          f"{self.dp_size_local} local DP ranks and "
                          f"{self.api_server_count} API servers")
                    self.servers.append((server, sargs))
                except Exception as e:
                    print(f"Failed to start hybrid LB node {node}: {e}")
                    raise

            thread = threading.Thread(target=start_server,
                                      args=(node_id, server_args))
            thread.start()

            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        # Give servers additional time to fully initialize and coordinate
        time.sleep(3)

        if len(self.servers) != self.num_nodes:
            raise Exception("Servers failed to start")

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances."""
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]


@pytest.fixture(scope="module", params=[1])  # Only 1 API server for now
def servers(request, default_server_args):
    api_server_count = request.param
    with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
                               default_server_args, DP_SIZE_LOCAL,
                               TP_SIZE) as server_list:
        yield server_list


@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    # Create a client for each node (each node has its own API endpoint)
    async with AsyncExitStack() as stack:
        yield [
            await stack.enter_async_context(server.get_async_client())
            for server, _ in servers
        ]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
                                    servers: list[tuple[RemoteOpenAIServer,
                                                        list[str]]],
                                    model_name: str) -> None:

    async def make_request(client: openai.AsyncOpenAI):
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # Test single request to each node
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(
            f"Hybrid LB node {i} handled single completion request successfully"
        )

    await asyncio.sleep(0.5)

    # Send requests to all nodes - each should balance within its local DP ranks
    num_requests_per_node = 25  # Total 50 requests across 2 nodes
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_node)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(completion is not None for completion in results)

    await asyncio.sleep(0.5)

    # Second burst of requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_node)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(completion is not None for completion in results)

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(
        f"Successfully completed hybrid LB test with {len(clients)} nodes "
        f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
    )

    # Check request balancing within each node
    for i, (server, _) in enumerate(servers):
        print(f"Checking request balancing for node {i}")
        check_request_balancing(server, DP_SIZE_LOCAL)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_hybrid_lb_completion_streaming(clients: list[
        openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
        model_name: str) -> None:
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    # Test single request to each node
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(
            f"Hybrid LB node {i} handled single streaming request successfully"
        )

    await asyncio.sleep(0.5)

    # Send streaming requests to all nodes
    num_requests_per_node = 25  # Total 50 requests across 2 nodes
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_node)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    await asyncio.sleep(0.5)

    # Second burst of streaming requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_node)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_node * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(f"Successfully completed hybrid LB streaming test with "
          f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
          f"API server count: {api_server_count})")

    # Check request balancing within each node
    for i, (server, _) in enumerate(servers):
        print(f"Checking streaming request balancing for node {i}")
        check_request_balancing(server, DP_SIZE_LOCAL)
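
Aside (editor's illustration, not part of the diff): the test above talks to each node's endpoint directly. In a real hybrid-LB deployment, an external load balancer spreads traffic across the per-node API servers, while each API server balances only over its local DP ranks. Below is a minimal sketch of that outer layer, assuming two nodes serving at 127.0.0.1:8000 and 127.0.0.1:8001 as in the test; the round-robin policy, endpoints, and model name are assumptions, not something this PR prescribes.

import asyncio
from itertools import cycle

import openai

# Hypothetical per-node endpoints started in hybrid LB mode (see test above).
NODE_URLS = ["http://127.0.0.1:8000/v1", "http://127.0.0.1:8001/v1"]


async def main() -> None:
    clients = [
        openai.AsyncOpenAI(base_url=url, api_key="EMPTY") for url in NODE_URLS
    ]
    # "External LB" stand-in: round-robin across node endpoints. Each node's
    # API server then balances the request across its own local DP ranks.
    node_picker = cycle(clients)

    async def one_request() -> str:
        client = next(node_picker)
        completion = await client.completions.create(
            model="ibm-research/PowerMoE-3b",
            prompt="Hello, my name is",
            max_tokens=10)
        return completion.choices[0].text

    texts = await asyncio.gather(*[one_request() for _ in range(8)])
    print(texts)


if __name__ == "__main__":
    asyncio.run(main())

In practice the external layer would be a Kubernetes Service, nginx, or similar; round-robin is used here only to keep the sketch short.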
12 changes: 10 additions & 2 deletions vllm/config.py
@@ -1906,8 +1906,16 @@ class ParallelConfig:
     """Backend to use for data parallel, either "mp" or "ray"."""
     data_parallel_external_lb: bool = False
     """Whether to use "external" DP LB mode. Applies only to online serving
-    and when data_parallel_size > 0. Set implicitly when
-    data_parallel_rank is provided explicitly to vllm serve."""
+    and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
+    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
+    is provided explicitly to vllm serve."""
+    data_parallel_hybrid_lb: bool = False
+    """Whether to use "hybrid" DP LB mode. Applies only to online serving
+    and when data_parallel_size > 0. Enables running an AsyncLLM
+    and API server on a "per-node" basis where vLLM load balances
+    between local data parallel ranks, but an external LB balances
+    between vLLM nodes/replicas. Set explicitly in conjunction with
+    --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     enable_eplb: bool = False
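
For reference, the new flags map onto a per-node launch roughly as sketched below (editor's illustration, not taken from this PR): each node starts its own API server with --data-parallel-hybrid-lb, --data-parallel-size-local, and a --data-parallel-start-rank offset, while --data-parallel-address and --data-parallel-rpc-port point back at node 0 for cross-node coordination. The model name, head-node IP, RPC port, and serving port are placeholder assumptions.

# Sketch: build the per-node `vllm serve` command for a 2-node, DP=4
# hybrid-LB deployment (2 local ranks per node), mirroring the flags
# exercised by tests/v1/test_hybrid_lb_dp.py. Values below are placeholders.
DP_SIZE = 4
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES
MODEL = "ibm-research/PowerMoE-3b"
HEAD_NODE_IP = "192.0.2.10"  # hypothetical address of node 0


def serve_command(node_id: int) -> list[str]:
    # Each node's API server load balances only across the DP ranks it owns;
    # an external LB is expected to spread traffic over the per-node
    # endpoints (all listening on port 8000 in this sketch).
    return [
        "vllm", "serve", MODEL,
        "--data-parallel-size", str(DP_SIZE),
        "--data-parallel-size-local", str(DP_SIZE_LOCAL),
        "--data-parallel-start-rank", str(node_id * DP_SIZE_LOCAL),
        "--data-parallel-hybrid-lb",
        "--data-parallel-address", HEAD_NODE_IP,
        "--data-parallel-rpc-port", "13345",
        "--port", "8000",
    ]


for node in range(NUM_NODES):
    print(f"node {node}:", " ".join(serve_command(node)))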