@@ -340,3 +340,84 @@ def test_full_block_prompt():
340
340
output = outputs [0 ]
341
341
assert output .finish_reason == FinishReason .STOP
342
342
assert_scheduler_empty (scheduler )
343
+
344
+
345
def test_cannot_schedule_after_recv():
    """
    Test that we can handle no schedule after recv due to not
    enough remaining KV blocks.
    """
    # NOTE: the KVCacheManager will use 1 null block, so only 5 of the
    # 6 total blocks are actually usable for KV data.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Sizing: the local prompt fills exactly 2 blocks (plus 1 more once a
    # decode token is scheduled); the remote prompt needs 2.5 blocks.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))

    req_local = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
    req_remote = create_request(request_id=2,
                                num_tokens=NUM_TOKENS_REMOTE,
                                do_remote_prefill=True)

    def _counts():
        # (running, waiting) queue lengths, for compact assertions.
        return len(scheduler.running), len(scheduler.waiting)

    # STEP 1: schedule the local request; 3 blocks are now in use
    # (2 for the prompt, 1 for decode).
    scheduler.add_request(req_local)
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_local])
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (1, 0)

    # STEP 2: add the remote request; 5 blocks in use
    # (2 new ones allocated for the incoming remote KV).
    scheduler.add_request(req_remote)
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_local])
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (1, 1)

    # STEP 3: the remote KV transfer finishes (still 5 blocks in use);
    # the remote request stays in the waiting queue.
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(
        reqs=[req_local], finished_recving=[req_remote.request_id])
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (1, 1)

    # STEP 4: try to schedule the remote request — not enough free blocks,
    # so it must remain waiting.
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_local])
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (1, 1)

    # STEP 5: the local request hits EOS; its blocks are freed.
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_local], use_eos=True)
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (0, 1)

    # STEP 6: now the remote request fits, and its 2 prompt blocks are
    # already computed from the finished recv.
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_remote])
    assert (sched_out.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_PROMPT_BLOCKS * BLOCK_SIZE)
    scheduler.update_from_output(sched_out, mr_out)
    assert _counts() == (1, 0)

    # STEP 7: finish the remote request and confirm everything is freed.
    sched_out = scheduler.schedule()
    mr_out = create_model_runner_output(reqs=[req_remote], use_eos=True)
    scheduler.update_from_output(sched_out, mr_out)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)