
Commit 3336814

[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent ef64044 commit 3336814


3 files changed: 15 additions, 10 deletions


vllm/v1/kv_cache_interface.py

Lines changed: 8 additions & 5 deletions
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
     def type_id(self) -> str:
         """
         The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g.,
-        different number of tokens like full attention vs sliding window
-        attention, different KV cache size per token like layers with different
+        Return different strings for layers with different KV cache type (e.g.,
+        different number of tokens like full attention vs sliding window
+        attention, different KV cache size per token like layers with different
         number of heads)

         Returns:
@@ -59,14 +59,17 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool

     @property
     def type_id(self) -> str:
         return f"full_attention_{self.block_size}_{self.page_size_bytes}"

     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)

     def bytes_for_tokens(self, num_tokens: int) -> int:
@@ -104,7 +107,7 @@ class KVCacheConfig:
     2. (not implemented yet) A model with the same number of full attention
     layers and sliding window attention layers: two groups, one for full
     attention layers and one for sliding window attention layers.
-    3. (not implemented yet) A model with 2 full attention layers and 4 sliding
+    3. (not implemented yet) A model with 2 full attention layers and 4 sliding
     window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
     """
     groups: list[list[str]]
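
The page-size change above is the heart of the fix: standard attention caches a key and a value per token (hence the factor of 2), while MLA stores a single compressed latent vector, so only one copy is needed. Below is a minimal standalone sketch of the arithmetic, using a hypothetical layer shape (1 KV head, head size 576, bfloat16) chosen purely for illustration and computing the dtype size with torch directly instead of vLLM's get_dtype_size helper:

import torch

def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype: torch.dtype, use_mla: bool) -> int:
    # Mirrors the logic above: MLA keeps one latent vector per token,
    # standard attention keeps separate K and V entries.
    coef = 1 if use_mla else 2
    dtype_size = torch.empty((), dtype=dtype).element_size()
    return coef * block_size * num_kv_heads * head_size * dtype_size

# Hypothetical shape, for illustration only.
print(page_size_bytes(16, 1, 576, torch.bfloat16, use_mla=True))   # 18432 bytes per block
print(page_size_bytes(16, 1, 576, torch.bfloat16, use_mla=False))  # 36864 bytes per block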

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 2 deletions
@@ -1460,21 +1460,22 @@ def get_kv_cache_spec(self) -> KVCacheSpec:

         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue

             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
                     block_size=block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
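
Because type_id already embeds page_size_bytes, threading use_mla into the spec is enough to keep MLA layers distinct from same-shaped non-MLA layers; type_id itself does not need to change. A hedged sketch of that behavior, assuming the module is importable as vllm.v1.kv_cache_interface and using made-up shapes:

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec

mla_spec = FullAttentionSpec(block_size=16, num_kv_heads=1, head_size=576,
                             dtype=torch.bfloat16, use_mla=True)
mha_spec = FullAttentionSpec(block_size=16, num_kv_heads=1, head_size=576,
                             dtype=torch.bfloat16, use_mla=False)

# Page sizes differ by a factor of two, so the type identifiers differ as well.
assert 2 * mla_spec.page_size_bytes == mha_spec.page_size_bytes
assert mla_spec.type_id != mha_spec.type_id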

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 3 deletions
@@ -303,10 +303,10 @@ def get_model(self) -> nn.Module:

     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
-        Generates the KVCacheSpec by parsing the kv cache format from each
+        Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.
         Returns:
-            KVCacheSpec: A dictionary mapping layer names to their KV cache
+            KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """

@@ -323,6 +323,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
@@ -764,7 +765,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
-            kv_cache_config: Configuration for the KV cache, including the KV
+            kv_cache_config: Configuration for the KV cache, including the KV
             cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
