
Commit 645d061

Support for attention-free models, revisited to reuse the existing KVCache manager.
Signed-off-by: Christian Pinto <[email protected]>
1 parent eda0697 commit 645d061

5 files changed: +53 −35 lines changed


vllm/v1/core/kv_cache_coordinator.py

Lines changed: 3 additions & 1 deletion
@@ -219,7 +219,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        # attention free models are initialized with 0 kv_cache_groups
+        if len(self.kv_cache_config.kv_cache_groups) > 0:
+            self.verify_and_split_kv_cache_groups()
 
     def verify_and_split_kv_cache_groups(self) -> None:
         """

vllm/v1/core/kv_cache_manager.py

Lines changed: 11 additions & 6 deletions
@@ -84,12 +84,17 @@ def __init__(
         self.log_stats = log_stats
         # FIXME: make prefix cache stats conditional on log_stats
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
-        assert len(
-            set(g.kv_cache_spec.block_size
-                for g in kv_cache_config.kv_cache_groups)
-        ) == 1, "Only one block size is supported for now"
-        self.block_size = kv_cache_config.kv_cache_groups[
-            0].kv_cache_spec.block_size
+
+        if len(kv_cache_config.kv_cache_groups) == 0:
+            # This is an attention free model that is started with 0 KVCache groups.
+            self.block_size = 0
+        else:
+            assert len(
+                set(g.kv_cache_spec.block_size
+                    for g in kv_cache_config.kv_cache_groups)
+            ) == 1, "Only one block size is supported for now"
+            self.block_size = kv_cache_config.kv_cache_groups[
+                0].kv_cache_spec.block_size
 
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
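
A self-contained sketch of the block-size selection above; the _Spec and _Group dataclasses are hypothetical stand-ins for the relevant fields of vLLM's KVCacheSpec and KV cache group types:

from dataclasses import dataclass


@dataclass
class _Spec:            # stand-in for the relevant part of KVCacheSpec
    block_size: int


@dataclass
class _Group:           # stand-in for a KV cache group entry
    kv_cache_spec: _Spec


def pick_block_size(kv_cache_groups: list[_Group]) -> int:
    if len(kv_cache_groups) == 0:
        # Attention-free model started with 0 KVCache groups.
        return 0
    assert len({g.kv_cache_spec.block_size for g in kv_cache_groups}) == 1, \
        "Only one block size is supported for now"
    return kv_cache_groups[0].kv_cache_spec.block_size


print(pick_block_size([]))                   # 0 for attention-free models
print(pick_block_size([_Group(_Spec(16))]))  # 16 for a regular model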

vllm/v1/core/kv_cache_utils.py

Lines changed: 17 additions & 1 deletion
@@ -551,6 +551,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         ValueError: If there is not enough memory available for the KV cache.
     """
 
+    # No need to check for available memory if the model is attention free
+    if vllm_config.model_config.is_attention_free:
+        return
+
     if available_memory <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
@@ -736,6 +740,11 @@ def is_kv_cache_page_size_uniform(
     page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
     return len(page_sizes) == 1
 
+
+def is_kv_cache_type_attention_free(
+        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
+
+    if "attention_free" in kv_cache_spec:
+        return True
 
 def _get_kv_cache_config_uniform_page_size(
     vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
@@ -877,6 +886,11 @@ def _get_kv_cache_config_uniform_page_size(
     return kv_cache_config
 
 
+def _get_kv_cache_config_attention_free() -> KVCacheConfig:
+    return KVCacheConfig(num_blocks=1,
+                         kv_cache_tensors=[],
+                         kv_cache_groups=[])
+
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
     This function tries to convert the KV cache specs to one type if the model
@@ -943,7 +957,9 @@ def get_kv_cache_config(
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)
 
-    if is_kv_cache_type_uniform(kv_cache_spec):
+    if is_kv_cache_type_attention_free(kv_cache_spec):
+        return _get_kv_cache_config_attention_free()
+    elif is_kv_cache_type_uniform(kv_cache_spec):
        # KV cache of all layers are the same, which is true for
        # most models. Allocate the same amount of memory for
        # each layer.
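
Taken together, these helpers short-circuit get_kv_cache_config before any memory check. A hedged sketch of that dispatch with stand-in types (the real KVCacheConfig carries more fields, and the uniform/hybrid paths are untouched by this commit):

from dataclasses import dataclass, field


@dataclass
class _KVCacheConfig:              # stand-in for vLLM's KVCacheConfig
    num_blocks: int
    kv_cache_tensors: list = field(default_factory=list)
    kv_cache_groups: list = field(default_factory=list)


def _is_attention_free(kv_cache_spec: dict) -> bool:
    # The commit keys the attention-free case on a sentinel spec entry.
    return "attention_free" in kv_cache_spec


def _get_config(kv_cache_spec: dict, available_memory: int) -> _KVCacheConfig:
    if _is_attention_free(kv_cache_spec):
        # One placeholder block, no tensors, no groups; memory is not checked.
        return _KVCacheConfig(num_blocks=1)
    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks.")
    raise NotImplementedError("uniform / hybrid paths unchanged by this commit")


print(_get_config({"attention_free": None}, available_memory=0))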

vllm/v1/engine/core.py

Lines changed: 16 additions & 27 deletions
@@ -134,33 +134,22 @@ def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
 
-        #TODO: CP start from here
-        if vllm_config.model_config.is_attention_free:
-            # No need for initializing anything related to KV cache if the model
-            # is attention free.
-            kv_cache_specs = []
-            kv_cache_configs = [
-                KVCacheConfig(num_blocks=0,
-                              kv_cache_tensors=[],
-                              kv_cache_groups=[])
-            ]
-        else:
-            # Get all kv cache needed by the model
-            kv_cache_specs = self.model_executor.get_kv_cache_specs()
-
-            # Profiles the peak memory usage of the model to determine how much
-            # memory can be allocated for kv cache.
-            available_gpu_memory = (
-                self.model_executor.determine_available_memory())
-
-            assert len(kv_cache_specs) == len(available_gpu_memory)
-            # Get the kv cache tensor size
-            kv_cache_configs = [
-                get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
-                                    available_gpu_memory_one_worker)
-                for kv_cache_spec_one_worker, available_gpu_memory_one_worker
-                in zip(kv_cache_specs, available_gpu_memory)
-            ]
+        # Get all kv cache needed by the model
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()
+
+        # Profiles the peak memory usage of the model to determine how much
+        # memory can be allocated for kv cache.
+        available_gpu_memory = (
+            self.model_executor.determine_available_memory())
+
+        assert len(kv_cache_specs) == len(available_gpu_memory)
+        # Get the kv cache tensor size
+        kv_cache_configs = [
+            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
+                                available_gpu_memory_one_worker)
+            for kv_cache_spec_one_worker, available_gpu_memory_one_worker
+            in zip(kv_cache_specs, available_gpu_memory)
+        ]
 
         # Since we use a shared centralized controller, we need the
         # `kv_cache_config` to be consistent across all workers to make sure
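
With the special case removed here, the engine path is uniform and the attention-free handling lives in the executor and in get_kv_cache_config. A small end-to-end sketch of the resulting flow; the functions and the memory numbers below are stand-ins for the real executor RPCs, not vLLM's API:

def determine_available_memory(attention_free: bool) -> list[int]:
    # The executor reports 0 bytes per worker for attention-free models.
    return [0] if attention_free else [8 * 1024 ** 3]


def get_kv_cache_specs(attention_free: bool) -> list[dict]:
    # The executor reports a sentinel spec instead of per-layer specs.
    return [{"attention_free": None}] if attention_free else [{"layer.0": None}]


def get_kv_cache_config(spec: dict, available_memory: int) -> dict:
    if "attention_free" in spec:
        return {"num_blocks": 1, "kv_cache_groups": []}
    assert available_memory > 0
    return {"num_blocks": available_memory // (16 * 1024),
            "kv_cache_groups": ["full"]}


specs = get_kv_cache_specs(attention_free=True)
memory = determine_available_memory(attention_free=True)
assert len(specs) == len(memory)
configs = [get_kv_cache_config(s, m) for s, m in zip(specs, memory)]
print(configs)  # [{'num_blocks': 1, 'kv_cache_groups': []}]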

vllm/v1/executor/abstract.py

Lines changed: 6 additions & 0 deletions
@@ -73,10 +73,16 @@ def register_failure_callback(self, callback: FailureCallback):
         pass
 
     def determine_available_memory(self) -> list[int]:  # in bytes
+        if self.vllm_config.model_config.is_attention_free:
+            return [0]
+
         output = self.collective_rpc("determine_available_memory")
         return output
 
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
+        if self.vllm_config.model_config.is_attention_free:
+            return [{"attention_free": KVCacheSpec(block_size=0)}]
+
         output = self.collective_rpc("get_kv_cache_spec")
        return output
 
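
A standalone sketch of this executor-level short-circuit, with a stub standing in for the collective_rpc worker calls:

class _SketchExecutor:
    def __init__(self, is_attention_free: bool) -> None:
        self.is_attention_free = is_attention_free

    def determine_available_memory(self) -> list[int]:  # in bytes
        if self.is_attention_free:
            return [0]                      # nothing to profile, no KV cache needed
        return [8 * 1024 ** 3]              # stand-in for collective_rpc(...)

    def get_kv_cache_specs(self) -> list[dict]:
        if self.is_attention_free:
            return [{"attention_free": {"block_size": 0}}]
        return [{"layer.0": {"block_size": 16}}]  # stand-in for collective_rpc(...)


ex = _SketchExecutor(is_attention_free=True)
print(ex.determine_available_memory())  # [0]
print(ex.get_kv_cache_specs())          # [{'attention_free': {'block_size': 0}}]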
