Commit 8ed5421

[V1] Eagerly remove finished requests from the batch (#14388)

Signed-off-by: Nick Hill <[email protected]>
Parent: c6359e8

9 files changed: +58 additions, -16 deletions

tests/v1/engine/test_engine_core.py
Lines changed: 10 additions & 0 deletions

@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
         engine_core.add_request(req)
         assert len(engine_core.scheduler.waiting) == 1
         assert len(engine_core.scheduler.running) == 0
+        assert engine_core.scheduler.has_unfinished_requests()
+        assert not engine_core.scheduler.has_finished_requests()
 
         _ = engine_core.step()
         assert len(engine_core.scheduler.waiting) == 0
         assert len(engine_core.scheduler.running) == 1
+        assert engine_core.scheduler.has_unfinished_requests()
+        assert not engine_core.scheduler.has_finished_requests()
 
         engine_core.abort_requests([request_id])
         assert len(engine_core.scheduler.waiting) == 0
         assert len(engine_core.scheduler.running) == 0
+        assert not engine_core.scheduler.has_unfinished_requests()
+        assert engine_core.scheduler.has_finished_requests()
+
+        _ = engine_core.step()
+        assert not engine_core.scheduler.has_unfinished_requests()
+        assert not engine_core.scheduler.has_finished_requests()
 
         # Add, step, abort 1 of the 3.
         req0 = make_request()

tests/v1/engine/test_engine_core_client.py
Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
         engine_core_outputs = client.get_output().outputs
 
         if len(engine_core_outputs) == 0:
-            break
+            continue
 
         all_finished = True
         for out in engine_core_outputs:
@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
         engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
-            break
+            continue
 
         all_finished = True
         for out in engine_core_outputs:
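
The switch from break to continue reflects the new behavior: an engine step can now legitimately return zero outputs while requests are still draining (for example, the extra step that only flushes just-finished requests out of the batch), so the test helper must keep polling instead of exiting. A rough, self-contained sketch of that loop shape, assuming only that the client exposes get_output():

def poll_until_finished(client, max_steps: int = 10_000):
    # Illustrative only; the real test helpers also accumulate per-request output.
    for _ in range(max_steps):
        outs = client.get_output().outputs
        if not outs:
            continue  # an empty step is normal now; keep polling
        if all(out.finished for out in outs):
            return
    raise TimeoutError("requests did not finish in time")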

vllm/v1/core/scheduler.py
Lines changed: 10 additions & 1 deletion

@@ -682,7 +682,8 @@ def finish_requests(
         assert RequestStatus.is_finished(finished_status)
         if isinstance(request_ids, str):
             request_ids = (request_ids, )
-        request_ids = set(request_ids)
+        else:
+            request_ids = set(request_ids)
 
         for req_id in request_ids:
             request = self.requests.get(req_id)
@@ -714,6 +715,14 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0
 
+    def has_finished_requests(self) -> bool:
+        return len(self.finished_req_ids) > 0
+
+    def has_requests(self):
+        """Returns True if there are unfinished requests, or finished requests
+        not yet returned in SchedulerOutputs."""
+        return self.has_unfinished_requests() or self.has_finished_requests()
+
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""
        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
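
To make the relationship between the new predicates concrete, here is a minimal, runnable sketch (illustrative only; the real Scheduler tracks much more state, but the boolean logic is the same):

class MiniScheduler:
    def __init__(self):
        self.requests = {}             # unfinished requests, keyed by request id
        self.finished_req_ids = set()  # finished but not yet reported in SchedulerOutputs

    def has_unfinished_requests(self) -> bool:
        return len(self.requests) > 0

    def has_finished_requests(self) -> bool:
        return len(self.finished_req_ids) > 0

    def has_requests(self) -> bool:
        # The engine core keeps stepping while either kind of request remains.
        return self.has_unfinished_requests() or self.has_finished_requests()


s = MiniScheduler()
s.finished_req_ids.add("req-0")   # e.g. a request aborted before producing output
assert not s.has_unfinished_requests()
assert s.has_requests()           # one more step is needed to flush "req-0"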

vllm/v1/engine/async_llm.py
Lines changed: 3 additions & 3 deletions

@@ -253,13 +253,14 @@ async def _run_output_handler(self):
         while True:
             # 1) Pull EngineCoreOutputs from the EngineCore.
             outputs = await self.engine_core.get_output_async()
+            num_outputs = len(outputs.outputs)
 
-            iteration_stats = IterationStats() if self.log_stats else None
+            iteration_stats = IterationStats() if (
+                self.log_stats and num_outputs) else None
 
             # Split outputs into chunks of at most
             # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
             # event loop for too long.
-            num_outputs = len(outputs.outputs)
             if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                 slices = (outputs.outputs, )
             else:
@@ -313,7 +314,6 @@ def _record_stats(
             return
 
         assert scheduler_stats is not None
-        assert iteration_stats is not None
         for stat_logger in self.stat_loggers:
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)
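
The net effect here is that IterationStats is only built for iterations that actually produced outputs; otherwise None is passed through _record_stats to the stat loggers, which are updated below to accept it. A small hedged sketch of that decision (simplified names, not the actual AsyncLLM code):

from typing import Optional

class IterationStatsSketch:   # stand-in for vLLM's IterationStats (illustrative)
    pass

def maybe_iteration_stats(log_stats: bool, num_outputs: int) -> Optional[IterationStatsSketch]:
    # Build per-iteration stats only when logging is enabled AND the step
    # produced outputs; the loggers must therefore tolerate receiving None.
    return IterationStatsSketch() if (log_stats and num_outputs) else None

assert maybe_iteration_stats(True, 0) is None
assert maybe_iteration_stats(False, 5) is None
assert maybe_iteration_stats(True, 5) is not None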

vllm/v1/engine/core.py
Lines changed: 4 additions & 2 deletions

@@ -153,7 +153,9 @@ def abort_requests(self, request_ids: list[str]):
     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""
 
-        if not self.scheduler.has_unfinished_requests():
+        # Check for any requests remaining in the scheduler - unfinished,
+        # or finished and not yet removed from the batch.
+        if not self.scheduler.has_requests():
             return EngineCoreOutputs(
                 outputs=[],
                 scheduler_stats=self.scheduler.make_stats(),
@@ -335,7 +337,7 @@ def run_busy_loop(self):
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
-            while not self.scheduler.has_unfinished_requests():
+            while not self.scheduler.has_requests():
                 logger.debug("EngineCore busy loop waiting.")
                 req = self.input_queue.get()
                 self._handle_client_request(*req)
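
A toy, runnable illustration of why the busy loop now keys off has_requests() rather than has_unfinished_requests(): once the last request finishes or is aborted, the scheduler may still hold its id in finished_req_ids, and one more step() is needed to flush it before the loop can safely block on the input queue again. The names below are stand-ins, not vLLM's:

class ToyScheduler:
    def __init__(self):
        self.unfinished = set()
        self.finished_req_ids = set()

    def has_requests(self) -> bool:
        return bool(self.unfinished) or bool(self.finished_req_ids)

    def step(self) -> list:
        flushed = sorted(self.finished_req_ids)   # report and drop finished requests
        self.finished_req_ids.clear()
        return flushed


sched = ToyScheduler()
sched.finished_req_ids.add("aborted-req")
assert sched.has_requests()             # must NOT block on the input queue yet
assert sched.step() == ["aborted-req"]  # the extra step flushes the finished request
assert not sched.has_requests()         # now it is safe to wait for new work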

vllm/v1/metrics/loggers.py
Lines changed: 8 additions & 4 deletions

@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):
 
     @abstractmethod
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         ...
 
     def log(self):  # noqa
@@ -56,10 +56,11 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
         return float(np.sum(tracked_stats) / (now - self.last_log_time))
 
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         """Log Stats to standard output."""
 
-        self._track_iteration_stats(iteration_stats)
+        if iteration_stats:
+            self._track_iteration_stats(iteration_stats)
 
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
@@ -319,7 +320,7 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
         info_gauge.set(1)
 
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         """Log to prometheus."""
         self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
         self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
@@ -331,6 +332,9 @@ def record(self, scheduler_stats: SchedulerStats,
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)
 
+        if iteration_stats is None:
+            return
+
         self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
         self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
         self.counter_generation_tokens.inc(
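
With the Optional[IterationStats] signature, callers can pass iteration_stats=None for steps that produced no outputs, and both loggers simply skip the per-iteration counters. A tiny logger written against the same contract, purely for illustration (not one of the vLLM classes):

from typing import Optional

class CountingLogger:
    def __init__(self):
        self.records = 0
        self.iterations_with_outputs = 0

    def record(self, scheduler_stats: dict, iteration_stats: Optional[dict]):
        self.records += 1                 # scheduler stats are always available
        if iteration_stats is None:
            return                        # no outputs this step; skip token counters
        self.iterations_with_outputs += 1


log = CountingLogger()
log.record(scheduler_stats={"num_running_reqs": 0}, iteration_stats=None)
log.record(scheduler_stats={"num_running_reqs": 1}, iteration_stats={"num_prompt_tokens": 7})
assert (log.records, log.iterations_with_outputs) == (2, 1)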

vllm/v1/outputs.py
Lines changed: 10 additions & 0 deletions

@@ -80,3 +80,13 @@ class ModelRunnerOutput:
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
+
+
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    sampled_token_ids=[],
+    spec_token_ids=None,
+    logprobs=None,
+    prompt_logprobs_dict={},
+)
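
This shared constant is what the GPU and TPU model runners (below) return when a step schedules zero tokens, e.g. when the only work is removing finished requests from the batch. A hedged sketch of the pattern using a local stand-in dataclass (the real ModelRunnerOutput has the fields listed above):

from dataclasses import dataclass, field

@dataclass
class OutputSketch:                    # stand-in for ModelRunnerOutput (illustrative)
    req_ids: list = field(default_factory=list)
    sampled_token_ids: list = field(default_factory=list)

EMPTY_OUTPUT = OutputSketch()          # module-level singleton, like EMPTY_MODEL_RUNNER_OUTPUT

def execute_model_sketch(total_num_scheduled_tokens: int) -> OutputSketch:
    # Mirrors the early return added to the model runners in the next two files.
    if not total_num_scheduled_tokens:
        return EMPTY_OUTPUT            # no forward pass, no sampling
    return OutputSketch(req_ids=["req-0"], sampled_token_ids=[[42]])

assert execute_model_sketch(0) is EMPTY_OUTPUT
assert execute_model_sketch(16).req_ids == ["req-0"]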

vllm/v1/worker/gpu_model_runner.py
Lines changed: 6 additions & 3 deletions

@@ -32,7 +32,8 @@
 from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -919,6 +920,9 @@ def execute_model(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
         self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -1069,15 +1073,14 @@ def execute_model(
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids)
 
-        model_runner_output = ModelRunnerOutput(
+        return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
-        return model_runner_output
 
     def generate_draft_token_ids(
         self,

vllm/v1/worker/tpu_model_runner.py
Lines changed: 5 additions & 1 deletion

@@ -29,7 +29,8 @@
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -546,6 +547,9 @@ def execute_model(
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
