Skip to content

Commit 7bc2ac6

Browse files
robertgshaw2-redhat authored and Yuqi Zhang committed
[Bugfix][Nixl] Fix Preemption Bug (vllm-project#18631)
Signed-off-by: [email protected] <[email protected]> Signed-off-by: Yuqi Zhang <[email protected]>
1 parent bef3250 commit 7bc2ac6

File tree

2 files changed

+97
-15
lines changed

2 files changed

+97
-15
lines changed

tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,84 @@ def test_full_block_prompt():
340340
output = outputs[0]
341341
assert output.finish_reason == FinishReason.STOP
342342
assert_scheduler_empty(scheduler)
343+
344+
345+
def test_cannot_schedule_after_recv():
346+
"""
347+
Test that we can handle no schedule after recv due to not
348+
enough remaining KV blocks.
349+
"""
350+
351+
# NOTE: the KVCacheManager will use 1 null block.
352+
# So there are 5 total working blocks.
353+
TOTAL_NUM_BLOCKS = 6
354+
vllm_config = create_vllm_config()
355+
scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
356+
357+
# Prime the KVCache.
358+
NUM_PROMPT_BLOCKS = 2
359+
BLOCK_SIZE = vllm_config.cache_config.block_size
360+
# Prompt will use 2 blocks + 1 block after we schedule.
361+
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
362+
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
363+
364+
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
365+
request_remote = create_request(request_id=2,
366+
num_tokens=NUM_TOKENS_REMOTE,
367+
do_remote_prefill=True)
368+
369+
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
370+
scheduler.add_request(request_normal)
371+
scheduler_output = scheduler.schedule()
372+
model_runner_output = create_model_runner_output(reqs=[request_normal])
373+
scheduler.update_from_output(scheduler_output, model_runner_output)
374+
assert len(scheduler.running) == 1
375+
assert len(scheduler.waiting) == 0
376+
377+
# Step 2: 5 blocks are in use (2 new for remote blocks).
378+
scheduler.add_request(request_remote)
379+
scheduler_output = scheduler.schedule()
380+
model_runner_output = create_model_runner_output(reqs=[request_normal])
381+
scheduler.update_from_output(scheduler_output, model_runner_output)
382+
assert len(scheduler.running) == 1
383+
assert len(scheduler.waiting) == 1
384+
385+
# Step 3: finish recving (5 blocks in use)
386+
scheduler_output = scheduler.schedule()
387+
model_runner_output = create_model_runner_output(
388+
reqs=[request_normal], finished_recving=[request_remote.request_id])
389+
scheduler.update_from_output(scheduler_output, model_runner_output)
390+
assert len(scheduler.running) == 1
391+
assert len(scheduler.waiting) == 1
392+
393+
# Step 4: try to schedule, not enough blocks.
394+
scheduler_output = scheduler.schedule()
395+
model_runner_output = create_model_runner_output(reqs=[request_normal])
396+
scheduler.update_from_output(scheduler_output, model_runner_output)
397+
assert len(scheduler.running) == 1
398+
assert len(scheduler.waiting) == 1
399+
400+
# Step 5: finish the request, free it.
401+
scheduler_output = scheduler.schedule()
402+
model_runner_output = create_model_runner_output(reqs=[request_normal],
403+
use_eos=True)
404+
scheduler.update_from_output(scheduler_output, model_runner_output)
405+
assert len(scheduler.running) == 0
406+
assert len(scheduler.waiting) == 1
407+
408+
# Step 6: now we can schedule (with 2 blocks computed).
409+
scheduler_output = scheduler.schedule()
410+
model_runner_output = create_model_runner_output(reqs=[request_remote])
411+
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
412+
NUM_PROMPT_BLOCKS * BLOCK_SIZE)
413+
scheduler.update_from_output(scheduler_output, model_runner_output)
414+
assert len(scheduler.running) == 1
415+
assert len(scheduler.waiting) == 0
416+
417+
# Step 7: free everything.
418+
scheduler_output = scheduler.schedule()
419+
model_runner_output = create_model_runner_output(reqs=[request_remote],
420+
use_eos=True)
421+
scheduler.update_from_output(scheduler_output, model_runner_output)
422+
_ = scheduler.schedule()
423+
assert_scheduler_empty(scheduler)

vllm/v1/core/sched/scheduler.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -310,15 +310,16 @@ def schedule(self) -> SchedulerOutput:
310310
break
311311

312312
request = self.waiting[0]
313-
num_prealloc_computed_tokens = 0
314-
# P/D: skip request if still waiting for remote kvs.
313+
314+
# KVTransfer: skip request if still waiting for remote kvs.
315315
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
316316
is_ready = self._update_waiting_for_remote_kv(request)
317317
if is_ready:
318318
request.status = RequestStatus.WAITING
319-
num_prealloc_computed_tokens = (
320-
request.num_computed_tokens)
321319
else:
320+
logger.debug(
321+
"%s is still in WAITING_FOR_REMOTE_KVS state.",
322+
request.request_id)
322323
self.waiting.popleft()
323324
skipped_waiting_requests.appendleft(request)
324325
continue
@@ -349,32 +350,32 @@ def schedule(self) -> SchedulerOutput:
349350
load_kv_async = False
350351

351352
# Get already-cached tokens.
352-
if num_prealloc_computed_tokens == 0:
353-
new_computed_blocks, num_native_computed_tokens = \
353+
if request.num_computed_tokens == 0:
354+
# Get locally-cached tokens.
355+
new_computed_blocks, num_new_local_computed_tokens = \
354356
self.kv_cache_manager.get_computed_blocks(
355357
request)
356358

357359
# Get externally-cached tokens if using a KVConnector.
358360
if self.connector is not None:
359361
num_external_computed_tokens, load_kv_async = (
360362
self.connector.get_num_new_matched_tokens(
361-
request, num_native_computed_tokens))
363+
request, num_new_local_computed_tokens))
362364

363365
# Total computed tokens (local + external).
364-
num_computed_tokens = (num_native_computed_tokens +
366+
num_computed_tokens = (num_new_local_computed_tokens +
365367
num_external_computed_tokens)
368+
# KVTransfer: WAITING reqs have num_computed_tokens > 0
369+
# after async KV recvs are completed.
366370
else:
367-
# P/D: skip checking prefix cache if loaded from remote kvs.
368371
new_computed_blocks = KVCacheBlocks.create_empty()
369-
num_native_computed_tokens = 0
370-
371-
# Total computed tokens (allocated in prior step).
372-
num_computed_tokens = num_prealloc_computed_tokens
372+
num_new_local_computed_tokens = 0
373+
num_computed_tokens = request.num_computed_tokens
373374

374375
encoder_inputs_to_schedule = None
375376
new_encoder_budget = encoder_budget
376377

377-
# P/D: loading remote KV, do not allocate for new work.
378+
# KVTransfer: loading remote KV, do not allocate for new work.
378379
if load_kv_async:
379380
assert num_external_computed_tokens > 0
380381
num_new_tokens = 0
@@ -405,7 +406,7 @@ def schedule(self) -> SchedulerOutput:
405406
new_blocks = self.kv_cache_manager.allocate_slots(
406407
request,
407408
num_new_tokens + num_external_computed_tokens,
408-
num_native_computed_tokens,
409+
num_new_local_computed_tokens,
409410
new_computed_blocks,
410411
num_lookahead_tokens=self.num_lookahead_tokens,
411412
delay_cache_blocks=load_kv_async,

0 commit comments

Comments (0)