Skip to content

Commit 5bf8789

Browse files
authored
[Bugfix] Block manager v2 with preemption and lookahead slots (#8824)
1 parent d153703 commit 5bf8789

File tree

9 files changed

+133
-116
lines changed

9 files changed

+133
-116
lines changed

tests/basic_correctness/test_preemption.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
@pytest.fixture(scope="module", autouse=True)
2424
def check_settings():
2525
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
26-
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
27-
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
26+
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
27+
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
28+
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
29+
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
2830
"tests/basic_correctness/test_preemption.py`")
2931

3032

@@ -199,6 +201,7 @@ def test_swap(
199201
@pytest.mark.parametrize("dtype", ["float"])
200202
@pytest.mark.parametrize("max_tokens", [96])
201203
@pytest.mark.parametrize("beam_width", [4])
204+
@pytest.mark.parametrize("use_v2_block_manager", [True, False])
202205
def test_swap_infeasible(
203206
vllm_runner,
204207
example_prompts,
@@ -207,6 +210,7 @@ def test_swap_infeasible(
207210
max_tokens: int,
208211
beam_width: int,
209212
worker_use_ray: bool,
213+
use_v2_block_manager: bool,
210214
) -> None:
211215
"""Verify infeasible swap request will be ignored."""
212216
BLOCK_SIZE = 16
@@ -223,6 +227,7 @@ def test_swap_infeasible(
223227
num_gpu_blocks_override=prefill_blocks + decode_blocks,
224228
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
225229
worker_use_ray=worker_use_ray,
230+
use_v2_block_manager=use_v2_block_manager,
226231
) as vllm_model:
227232
sampling_params = SamplingParams(n=beam_width,
228233
use_beam_search=True,

tests/core/block/test_block_manager_v2.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
373373
seq_group, num_lookahead_slots) == AllocStatus.NEVER
374374

375375

376+
@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
377+
@pytest.mark.parametrize("enable_caching", [False, True])
378+
def test_swap_in_infeasible(num_lookahead_slots, enable_caching):
379+
"""Verifies that swapping fails if there is not enough free blocks
380+
to account for unseen tokens and lookahead_slots.
381+
"""
382+
block_size = 8
383+
num_cpu_blocks = 1
384+
num_gpu_blocks = 1
385+
block_manager = BlockSpaceManagerV2(block_size,
386+
num_cpu_blocks,
387+
num_gpu_blocks,
388+
watermark=0,
389+
enable_caching=enable_caching)
390+
prompt_length = block_size - 3
391+
assert prompt_length > 0
392+
prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length)
393+
prompt.status = SequenceStatus.WAITING
394+
block_manager.allocate(seq_group)
395+
# Emulate a forward pass by appending a single token.
396+
# The block manager then knows how many unprocessed
397+
# tokens will be written in the next forward pass.
398+
token_id = 0
399+
prompt.status = SequenceStatus.RUNNING
400+
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
401+
402+
# Swap seq group from GPU -> CPU.
403+
assert block_manager.can_swap_out(seq_group)
404+
block_manager.swap_out(seq_group)
405+
prompt.status = SequenceStatus.SWAPPED
406+
407+
# Swap seq group from CPU -> GPU.
408+
# The number of unseen tokens is 1. If the number of existing
409+
# tokens plus the unseen ones and number of lookahead slots exceeds
410+
# the total token capacity of the available GPU blocks then the swap
411+
# should fail.
412+
num_unseen_tokens = 1
413+
if (num_lookahead_slots + num_unseen_tokens +
414+
prompt_length) <= (block_size * num_gpu_blocks):
415+
assert block_manager.can_swap_in(seq_group,
416+
num_lookahead_slots) == AllocStatus.OK
417+
else:
418+
assert block_manager.can_swap_in(
419+
seq_group, num_lookahead_slots) == AllocStatus.NEVER
420+
421+
376422
# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
377423

378424

@@ -400,7 +446,6 @@ def check_used(min_n, max_n=None):
400446
if max_n is None:
401447
max_n = min_n
402448
used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
403-
#print("check", min_n, used, max_n)
404449
assert min_n <= used
405450
assert used <= max_n
406451

tests/core/block/test_naive_block.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
104104
@staticmethod
105105
@pytest.mark.parametrize("num_blocks", [4])
106106
@pytest.mark.parametrize("block_size", [8])
107-
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
107+
def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size):
108108
""" Verify the allocator can correctly return the number of
109-
blocks touched, with different lookahead slots.
109+
full blocks touched.
110110
"""
111111
allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
112112
num_blocks=num_blocks,
@@ -124,7 +124,7 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
124124
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]
125125

126126
# All blocks are cached
127-
assert allocator_dst.get_num_blocks_touched(
127+
assert allocator_dst.get_num_full_blocks_touched(
128128
src_blocks) == num_blocks - 1
129129

130130
# Insert one non-full block in the src
@@ -136,9 +136,10 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
136136
src_blocks.append(allocate_non_full_block())
137137
src_blocks[-1].append_token_ids([0])
138138

139-
assert allocator_dst.get_num_blocks_touched(
140-
src_blocks, num_lookahead_slots=1) == num_blocks
141-
assert allocator_dst.get_num_blocks_touched(
142-
src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
143-
assert allocator_dst.get_num_blocks_touched(
144-
src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)
139+
assert allocator_dst.get_num_full_blocks_touched(
140+
src_blocks) == num_blocks - 1
141+
# Fill up the last source block and then invoke
142+
# get_num_full_blocks_touched
143+
src_blocks[-1].append_token_ids([0] * (block_size - 1))
144+
assert allocator_dst.get_num_full_blocks_touched(
145+
src_blocks) == num_blocks

tests/core/block/test_prefix_caching_block.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,10 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
318318
@staticmethod
319319
@pytest.mark.parametrize("num_blocks", [4])
320320
@pytest.mark.parametrize("block_size", [8])
321-
def test_prefix_caching_block_get_num_blocks_touched(
321+
def test_prefix_caching_block_get_num_full_blocks_touched(
322322
num_blocks, block_size):
323323
""" Verify the allocator can correctly return the number of
324-
blocks touched, when there are cached prefixes and different
325-
lookahead slots.
324+
full blocks touched, when there are cached prefixes.
326325
"""
327326
allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks,
328327
block_size=block_size)
@@ -346,28 +345,30 @@ def test_prefix_caching_block_get_num_blocks_touched(
346345
token_ids=token_ids,
347346
allocator=allocator_src,
348347
)
349-
350348
# All blocks are cached
351-
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0
349+
assert allocator_dst.get_num_full_blocks_touched(
350+
blocks_to_swap_in) == 0
352351

353352
# Free the first block in the dst
354353
allocator_dst.free(cached_blocks[0])
355354

356355
# Now the first block becomes dangling, the swapped blocks need
357356
# to reclaim the first block in the dst
358-
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1
357+
assert allocator_dst.get_num_full_blocks_touched(
358+
blocks_to_swap_in) == 1
359359

360360
# Insert one non-full block in the src
361361
non_full_block = allocator_src.allocate_mutable_block(
362362
blocks_to_swap_in[-1])
363363
non_full_block.append_token_ids([0])
364364
blocks_to_swap_in.append(non_full_block)
365-
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in,
366-
num_lookahead_slots=1) == 2
367-
assert allocator_dst.get_num_blocks_touched(
368-
blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2
369-
assert allocator_dst.get_num_blocks_touched(
370-
blocks_to_swap_in, num_lookahead_slots=block_size) == 3
365+
assert allocator_dst.get_num_full_blocks_touched(
366+
blocks_to_swap_in) == 1
367+
# Fill up the last mutable block and invoke get_num_full_blocks_touched.
368+
# Note: The last block is not cached so it will be touched.
369+
non_full_block.append_token_ids([0] * (block_size - 1))
370+
assert allocator_dst.get_num_full_blocks_touched(
371+
blocks_to_swap_in) == 2
371372

372373
@staticmethod
373374
@pytest.mark.parametrize("num_blocks", [1024])

vllm/core/block/cpu_gpu_block_allocator.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -259,25 +259,22 @@ def swap(self, blocks: List[Block], src_device: Device,
259259
current_swap_mapping[src_block_id] = dst_block_id
260260
return current_swap_mapping
261261

262-
def get_num_blocks_touched(self,
263-
blocks: List[Block],
264-
device: Device,
265-
num_lookahead_slots: int = 0) -> int:
266-
"""Returns the number of blocks that will be touched by
262+
def get_num_full_blocks_touched(self, blocks: List[Block],
263+
device: Device) -> int:
264+
"""Returns the number of full blocks that will be touched by
267265
swapping in/out the given blocks on to the 'device'.
268266
269267
Args:
270268
blocks: List of blocks to be swapped.
271269
device (Device): Device to swap the 'blocks' on.
272-
num_lookahead_slots (int): Number of lookahead slots used in
273-
speculative decoding, default to 0.
274270
275271
Returns:
276-
int: the number of blocks that will be touched by
272+
int: the number of full blocks that will be touched by
277273
swapping in/out the given blocks on to the 'device'.
274+
Non full blocks are ignored when deciding the number
275+
of blocks to touch.
278276
"""
279-
return self._allocators[device].get_num_blocks_touched(
280-
blocks, num_lookahead_slots)
277+
return self._allocators[device].get_num_full_blocks_touched(blocks)
281278

282279
def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
283280
"""Clears the copy-on-write (CoW) state and returns the mapping of

vllm/core/block/interfaces.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,7 @@ def promote_to_immutable_block(self, block: Block) -> BlockId:
181181
pass
182182

183183
@abstractmethod
184-
def get_num_blocks_touched(self,
185-
blocks: List[Block],
186-
num_lookahead_slots: int = 0) -> int:
184+
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
187185
pass
188186

189187
@abstractmethod
@@ -260,10 +258,8 @@ def get_common_computed_block_ids(
260258
pass
261259

262260
@abstractmethod
263-
def get_num_blocks_touched(self,
264-
blocks: List[Block],
265-
device: Device,
266-
num_lookahead_slots: int = 0) -> int:
261+
def get_num_full_blocks_touched(self, blocks: List[Block],
262+
device: Device) -> int:
267263
pass
268264

269265
@abstractmethod

vllm/core/block/naive_block.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
55
get_all_blocks_recursively)
66
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
7-
from vllm.utils import cdiv
87

98
Refcount = int
109

@@ -282,40 +281,26 @@ def get_common_computed_block_ids(
282281
def promote_to_immutable_block(self, block: Block) -> BlockId:
283282
raise NotImplementedError("There is no promotion for naive blocks")
284283

285-
def get_num_blocks_touched(self,
286-
blocks: List[Block],
287-
num_lookahead_slots: int = 0) -> int:
288-
"""Determine the number of blocks that will be touched by
289-
swapping in/out the given blocks from certain sequence
290-
group with the provided num_lookahead_slots.
284+
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
285+
"""Returns the number of full blocks that will be touched by
286+
swapping in/out.
291287
292288
Args:
293-
blocks (List[Block]): The potential blocks to swap.
294-
num_lookahead_slots (int): number of lookahead slots (0 for swap
295-
out).
296-
289+
blocks: List of blocks to be swapped.
297290
Returns:
298-
int: the number of blocks that will be touched by
299-
swapping in/out the given blocks and num_lookahead_slots.
291+
int: the number of full blocks that will be touched by
292+
swapping in/out the given blocks. Non full blocks are ignored
293+
when deciding the number of blocks to touch.
300294
"""
301295
# NOTE: for naive block, we use set to eliminate common blocks among
302296
# seqs. Only full blocks are counted: non-full blocks are skipped,
303297
# and lookahead slots are not taken into account by this
304298
# method.
305299
old_block_set = set()
306-
new_block_count = 0
307-
# TODO(cade): make sure the logic is correct and clean it up.
308300
for block in blocks:
309-
if not block.is_full and num_lookahead_slots != 0:
310-
new_block_count += 1
311-
if num_lookahead_slots > block.num_empty_slots:
312-
new_block_count += cdiv(
313-
num_lookahead_slots - block.num_empty_slots,
314-
self._block_size)
315-
else:
316-
old_block_set.add(block.block_id)
317-
num_touched_blocks = new_block_count + len(old_block_set)
318-
return num_touched_blocks
301+
if block.is_full:
302+
old_block_set.add(block)
303+
return len(old_block_set)
319304

320305
def swap_out(self, blocks: List[Block]) -> None:
321306
for block in blocks:

vllm/core/block/prefix_caching_block.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
99
NaiveBlockAllocator)
1010
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
11-
from vllm.utils import cdiv
1211

1312
PrefixHash = int
1413

@@ -576,37 +575,27 @@ def get_common_computed_block_ids(
576575
if ids
577576
])
578577

579-
def get_num_blocks_touched(self,
580-
blocks: List[Block],
581-
num_lookahead_slots: int = 0) -> int:
582-
"""Determine the number of blocks that will be touched by
583-
swapping in/out the given blocks from certain sequence
584-
group with the provided num_lookahead_slots.
578+
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
579+
"""Returns the number of full blocks that will be touched by
580+
swapping in/out.
585581
586582
Args:
587-
blocks (List[Block]): The potential blocks to swap.
588-
num_lookahead_slots (int): number of lookahead slots (0 for
589-
swap out).
590-
583+
blocks: List of blocks to be swapped.
591584
Returns:
592-
int: the number of blocks that will be touched by
593-
swapping in/out the given blocks and num_lookahead_slots.
585+
int: the number of full blocks that will be touched by
586+
swapping in/out the given blocks. Non full blocks are ignored
587+
when deciding the number of blocks to touch.
594588
"""
595-
num_touched_blocks = 0
589+
num_touched_blocks: int = 0
596590
for block in blocks:
597-
if not block.is_full:
591+
# If the block has a match in the cache and the cached
592+
# block is not referenced, then we still count it as a
593+
# touched block
594+
if block.is_full and (not self.is_block_cached(block) or \
595+
(block.content_hash is not None and \
596+
self._cached_blocks[block.content_hash] in \
597+
self.evictor)):
598598
num_touched_blocks += 1
599-
if num_lookahead_slots > block.num_empty_slots:
600-
num_touched_blocks += cdiv(
601-
num_lookahead_slots - block.num_empty_slots,
602-
self._block_size)
603-
else:
604-
# If the block has a match in the cache and the cached block
605-
# is not referenced, then we still count it as a touched block
606-
if not self.is_block_cached(block) or \
607-
(block.content_hash is not None and \
608-
self._cached_blocks[block.content_hash] in self.evictor):
609-
num_touched_blocks += 1
610599
return num_touched_blocks
611600

612601
def swap_out(self, blocks: List[Block]) -> None:

0 commit comments

Comments
 (0)