Commit 5a5753a

[BugFix] Fix multi-node offline data parallel (vllm-project#19937)

Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Will Eaton <[email protected]>

1 parent f863de6 commit 5a5753a

File tree

5 files changed: +31 -4 lines changed

.buildkite/test-pipeline.yaml
vllm/entrypoints/llm.py
vllm/v1/engine/core.py
vllm/v1/engine/core_client.py
vllm/v1/engine/llm_engine.py

.buildkite/test-pipeline.yaml
Lines changed: 3 additions & 0 deletions

@@ -615,13 +615,16 @@ steps:
   - vllm/executor/
   - vllm/model_executor/models/
   - tests/distributed/
+  - tests/examples/offline_inference/data_parallel.py
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
 
 - label: Distributed Tests (2 GPUs) # 40min
   mirror_hardwares: [amdexperimental]
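The two new commands run examples/offline_inference/data_parallel.py once per node, splitting the data-parallel ranks across nodes via --node-size/--node-rank while every rank rendezvouses at --master-addr/--master-port. A minimal sketch of that flag arithmetic and of the SPMD data split, assuming an even rank split per node; dp_ranks_for_node and shard_prompts are hypothetical helpers, not part of the example script:

```python
# Hypothetical helpers illustrating the CLI flags above; not the example script itself.

def dp_ranks_for_node(dp_size: int, node_size: int, node_rank: int) -> list[int]:
    """Global data-parallel ranks hosted by one node, assuming an even split."""
    assert dp_size % node_size == 0, "--dp-size must divide evenly across nodes"
    per_node = dp_size // node_size
    return list(range(node_rank * per_node, (node_rank + 1) * per_node))

def shard_prompts(prompts: list[str], dp_rank: int, dp_size: int) -> list[str]:
    """SPMD data split: each DP rank handles a disjoint slice of the inputs."""
    return prompts[dp_rank::dp_size]

# With --dp-size=2 --node-size=2, node 0 hosts global DP rank 0 and node 1
# hosts global DP rank 1; each runs its own LLM instance on its slice.
print(dp_ranks_for_node(dp_size=2, node_size=2, node_rank=0))       # [0]
print(dp_ranks_for_node(dp_size=2, node_size=2, node_rank=1))       # [1]
print(shard_prompts(["a", "b", "c", "d"], dp_rank=0, dp_size=2))    # ['a', 'c']
```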

vllm/entrypoints/llm.py
Lines changed: 2 additions & 0 deletions

@@ -1568,6 +1568,8 @@ def _run_engine(
                         pbar.update(n)
                     else:
                         pbar.update(1)
+                    if pbar.n == num_requests:
+                        pbar.refresh()
 
         if use_tqdm:
             pbar.close()
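The added pbar.refresh() guards against tqdm's redraw throttling: update() may skip repainting when calls arrive inside the mininterval/miniters window, so a bar can be closed while still showing a stale count. A small self-contained sketch of the same pattern (the request counts here are made up):

```python
from tqdm import tqdm

num_requests = 4
# Deliberately slow redraws to mimic updates landing inside the throttle window.
pbar = tqdm(total=num_requests, mininterval=60)

for n_finished in (1, 3):       # e.g. batches finishing one, then three requests
    pbar.update(n_finished)
    if pbar.n == num_requests:
        pbar.refresh()          # force a final repaint at 100%

pbar.close()
```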

vllm/v1/engine/core.py
Lines changed: 6 additions & 2 deletions

@@ -877,12 +877,16 @@ def run_busy_loop(self):
                         local_unfinished_reqs)
 
                 if not self.engines_running:
-                    if self.dp_rank == 0:
+                    if self.dp_rank == 0 or not self.has_coordinator:
                         # Notify client that we are pausing the loop.
                         logger.debug("Wave %d finished, pausing engine loop.",
                                      self.current_wave)
+                        # In the coordinator case, dp rank 0 sends updates to the
+                        # coordinator. Otherwise (offline spmd case), each rank
+                        # sends the update to its colocated front-end process.
+                        client_index = -1 if self.has_coordinator else 0
                         self.output_queue.put_nowait(
-                            (-1,
+                            (client_index,
                              EngineCoreOutputs(wave_complete=self.current_wave)))
                     self.current_wave += 1
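A toy sketch (not vLLM internals) of the routing decision above: when a DP coordinator exists, only rank 0 reports the finished wave and addresses it to the coordinator at client index -1; in the offline SPMD case there is no coordinator, so every rank addresses the notification to its own colocated front-end process at client index 0:

```python
import queue

def notify_wave_complete(output_queue: queue.Queue, dp_rank: int,
                         has_coordinator: bool, current_wave: int) -> None:
    """Toy version of the wave-complete notification routing."""
    if dp_rank == 0 or not has_coordinator:
        # -1 targets the coordinator; 0 targets this rank's local client.
        client_index = -1 if has_coordinator else 0
        output_queue.put_nowait((client_index, ("wave_complete", current_wave)))

q: queue.Queue = queue.Queue()
notify_wave_complete(q, dp_rank=1, has_coordinator=False, current_wave=0)
print(q.get_nowait())  # (0, ('wave_complete', 0)) -> delivered to the local client
```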

vllm/v1/engine/core_client.py
Lines changed: 19 additions & 1 deletion

@@ -155,6 +155,11 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         raise NotImplementedError
 
+    def dp_engines_running(self) -> bool:
+        """Returns True if data parallel engines are collectively in a
+        running state."""
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
@@ -282,6 +287,9 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)
 
+    def dp_engines_running(self) -> bool:
+        return False
+
 
 @dataclass
 class BackgroundResources:
@@ -384,6 +392,9 @@ def __init__(
         dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_rank
 
+        # State used for data parallel.
+        self.engines_running = False
+
         # SPMD mode is where there is an LLM instance per DP rank and
         # one core engine per LLM, see
         # examples/offline_inference/data_parallel.py.
@@ -539,6 +550,9 @@ def free_pending_messages(self):
         while self.pending_messages and self.pending_messages[-1][0].done:
             self.pending_messages.pop()
 
+    def dp_engines_running(self) -> bool:
+        return self.engines_running
+
 
 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -562,6 +576,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
             log_stats=log_stats,
         )
 
+        self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
         self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
 
         # Ensure that the outputs socket processing thread does not have
@@ -623,6 +638,8 @@ def get_output(self) -> EngineCoreOutputs:
         outputs = self.outputs_queue.get()
         if isinstance(outputs, Exception):
             raise self._format_exception(outputs) from None
+        if outputs.wave_complete is not None:
+            self.engines_running = False
         return outputs
 
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
@@ -650,6 +667,8 @@ def call_utility(self, method: str, *args) -> Any:
         return future.result()
 
     def add_request(self, request: EngineCoreRequest) -> None:
+        if self.is_dp:
+            self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)
 
     def abort_requests(self, request_ids: list[str]) -> None:
@@ -911,7 +930,6 @@ def __init__(self,
                  client_addresses: Optional[dict[str, str]] = None,
                  client_index: int = 0):
         self.current_wave = 0
-        self.engines_running = False
         # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, CoreEngine] = {}
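Taken together, the core_client.py changes track one client-side flag: it is set whenever a request is submitted in data-parallel mode and cleared when a wave_complete output is observed, with dp_engines_running() exposing it. A toy model of that state machine (not the real client classes):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Output:
    """Stand-in for EngineCoreOutputs; only the wave_complete field is modeled."""
    wave_complete: Optional[int] = None

class ToyDPClient:
    def __init__(self, dp_size: int):
        self.is_dp = dp_size > 1
        self.engines_running = False

    def add_request(self, request: str) -> None:
        if self.is_dp:
            self.engines_running = True   # submitting work starts a wave
        # ... forward the request to the engine core ...

    def get_output(self, outputs: Output) -> Output:
        if outputs.wave_complete is not None:
            self.engines_running = False  # the collective wave has finished
        return outputs

    def dp_engines_running(self) -> bool:
        return self.engines_running

client = ToyDPClient(dp_size=2)
client.add_request("req-0")
assert client.dp_engines_running()        # wave in progress
client.get_output(Output(wave_complete=0))
assert not client.dp_engines_running()    # wave finished
```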

vllm/v1/engine/llm_engine.py
Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
         if self.dp_group is None:
-            return has_unfinished
+            return has_unfinished or self.engine_core.dp_engines_running()
         return self.has_unfinished_requests_dp(has_unfinished)
 
     def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
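With dp_engines_running() folded into has_unfinished_requests(), an offline DP rank whose own requests are done still reports unfinished work until the whole wave completes, so its loop keeps stepping in sync with the other ranks instead of exiting early. A hedged usage sketch; StubCoreClient and ToyLLMEngine are illustrative stand-ins only:

```python
class StubCoreClient:
    """Stands in for the engine-core client; tracks the DP wave state."""
    def __init__(self) -> None:
        self._running = True
    def dp_engines_running(self) -> bool:
        return self._running
    def receive_wave_complete(self) -> None:
        self._running = False

class ToyLLMEngine:
    def __init__(self, core: StubCoreClient, local_unfinished: int = 0) -> None:
        self.engine_core = core
        self.local_unfinished = local_unfinished
    def has_unfinished_requests(self) -> bool:
        # dp_group is None in the per-rank offline (SPMD) case, so also ask
        # the core client whether the DP collective is still running.
        return self.local_unfinished > 0 or self.engine_core.dp_engines_running()

core = StubCoreClient()
engine = ToyLLMEngine(core)                  # this rank has no local requests left
print(engine.has_unfinished_requests())      # True  -> keep the loop alive
core.receive_wave_complete()
print(engine.has_unfinished_requests())      # False -> safe to stop
```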
