
Commit 5e6f939

[Attention] MLA move rotary embedding to cuda-graph region (#17668)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 760e3ec commit 5e6f939

6 files changed: +35 / -121 lines changed


vllm/attention/backends/mla/common.py

Lines changed: 11 additions & 60 deletions
@@ -211,8 +211,6 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
@@ -377,7 +375,6 @@ def graph_capture_get_metadata_for_batch(
             seq_start_loc=None,
             context_lens_tensor=None,
             block_tables=self._graph_block_tables[:batch_size],
-            input_positions=self._positions[:batch_size],
             head_dim=self.runner.model_config.get_head_size())
 
         if is_encoder_decoder_model:
@@ -393,7 +390,6 @@ def get_graph_input_buffers(self,
             "slot_mapping": attn_metadata.slot_mapping,
             "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
             "block_tables": attn_metadata.decode_metadata.block_tables,
-            "input_positions": attn_metadata.decode_metadata.input_positions,
         }
         if is_encoder_decoder_model:
             raise NotImplementedError(
@@ -405,16 +401,10 @@ def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        input_positions = attn_metadata.input_positions
-        num_positions = input_positions.shape[0]
         input_buffers["seq_lens_tensor"].copy_(
             attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
         input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        # CUDA graph buffer is padded so only perform a partial copy based on
-        # num_positions
-        input_buffers["input_positions"][:num_positions].copy_(
-            input_positions, non_blocking=True)
         if is_encoder_decoder_model:
             raise NotImplementedError(
                 "TritonMLAState does not support encoder/decoder yet")
@@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
-    # New for MLA (compared to FlashAttention)
-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
-
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
@@ -563,8 +548,6 @@ def prefill_metadata(self):
                                self.context_lens_tensor[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
-        input_positions = (None if self.input_positions is None else
-                           self.input_positions[:self.num_prefill_tokens])
 
         self._cached_prefill_metadata = self.__class__(
             # Required by ModelRunner
@@ -578,7 +561,6 @@ def prefill_metadata(self):
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
             # MLACommonMetadata
-            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
@@ -615,8 +597,6 @@ def decode_metadata(self):
                            self.seq_lens_tensor[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
-        input_positions = (None if self.input_positions is None else
-                           self.input_positions[self.num_prefill_tokens:])
 
         self._cached_decode_metadata = self.__class__(
             # Required by ModelRunner
@@ -646,7 +626,6 @@ def decode_metadata(self):
             if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=block_tables,
-            input_positions=input_positions,
             head_dim=self.head_dim,
             is_profile_run=self.is_profile_run)
         return self._cached_decode_metadata
@@ -765,7 +744,6 @@ def prepare(self):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.input_positions: List[int] = []
         self.multimodal_placeholder_maps: Dict[
             str,
             MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
@@ -786,13 +764,11 @@ def _add_seq_group(
         block_tables = inter_data.block_tables
 
         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block, input_positions) in zip(
+             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
-                inter_data.curr_sliding_window_blocks,
-                inter_data.input_positions):
-            self.input_positions.extend(input_positions)
+                inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
@@ -912,8 +888,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                            device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        input_positions = async_tensor_h2d(self.input_positions, torch.long,
-                                           device, self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
         query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
@@ -987,7 +961,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             multi_modal_placeholder_index_maps=None,  # Not Attention Related
             enable_kv_scales_calculation=False,
             # MLACommonMetadata
-            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
@@ -1033,7 +1006,6 @@ def __init__(
             qk_rope_head_dim: int,
             qk_head_dim: int,
             v_head_dim: int,
-            rotary_emb: RotaryEmbedding,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
         self.num_heads = num_heads
@@ -1048,10 +1020,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-
-        self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
         self.kv_b_proj = kv_b_proj
 
         self.triton_fa_func = triton_attention
@@ -1367,41 +1335,15 @@ def forward(
         has_decode = attn_metadata.decode_metadata is not None
         has_prefill = attn_metadata.prefill_metadata is not None
 
-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-        assert hasattr(attn_metadata, "input_positions")
-
         num_prefill_tokens: int = attn_metadata.num_prefill_tokens
         q = q.view(-1, self.num_heads, self.qk_head_dim)
 
         decode_q = q[num_prefill_tokens:]
-        decode_k_pe = k_pe[num_prefill_tokens:]
-        decode_input_positions = \
-            attn_metadata.input_positions[num_prefill_tokens:]
 
         prefill_q = q[:num_prefill_tokens]
        prefill_k_pe = k_pe[:num_prefill_tokens]
-        prefill_input_positions = \
-            attn_metadata.input_positions[:num_prefill_tokens]
         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
 
-        if has_decode:
-            decode_q_nope, decode_q_pe = decode_q.split(
-                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-            # Convert from (B, N, P) to (N, B, P)
-            decode_q_nope = decode_q_nope.transpose(0, 1)
-            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-            # Convert from (N, B, L) to (B, N, L)
-            decode_ql_nope = decode_ql_nope.transpose(0, 1)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                decode_input_positions, decode_q_pe, decode_k_pe)
-
-        if has_prefill:
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                prefill_input_positions, prefill_q_pe, prefill_k_pe)
-
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -1424,6 +1366,15 @@ def forward(
                 attn_metadata)
 
         if has_decode:
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
             output[num_prefill_tokens:] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
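
The `if has_decode:` block relocated above keeps the matrix-absorption step unchanged; only the rotary call was dropped from it (rope is now applied in the model forward instead). As a rough, self-contained illustration of the shape bookkeeping its comments describe (toy sizes and a random stand-in for `self.W_UK_T`, not the real absorbed weights), the split/transpose/bmm sequence is:

```python
import torch

# Toy shapes following the diff's comments: B = decode tokens, N = heads,
# P = qk_nope_head_dim, R = qk_rope_head_dim, L = kv_lora_rank.
B, N, P, R, L = 4, 16, 128, 64, 512
decode_q = torch.randn(B, N, P + R)
W_UK_T = torch.randn(N, P, L)  # stand-in for the absorbed up-projection

decode_q_nope, decode_q_pe = decode_q.split([P, R], dim=-1)
decode_q_nope = decode_q_nope.transpose(0, 1)      # (B, N, P) -> (N, B, P)
decode_ql_nope = torch.bmm(decode_q_nope, W_UK_T)  # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)    # (N, B, L) -> (B, N, L)
assert decode_ql_nope.shape == (B, N, L)
```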

vllm/attention/backends/rocm_aiter_mla.py

Lines changed: 2 additions & 4 deletions
@@ -148,13 +148,11 @@ def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
         block_tables = inter_data.block_tables
 
         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block, input_positions) in zip(
+             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
-                inter_data.curr_sliding_window_blocks,
-                inter_data.input_positions):
-            self.input_positions.extend(input_positions)
+                inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 3 additions & 2 deletions
@@ -808,8 +808,9 @@ def forward(
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]
 
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device)
+        if self.cos_sin_cache.device != positions.device:
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+                positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
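
The added guard migrates the cached cos/sin table at most once instead of calling `.to(positions.device)` unconditionally, which avoids re-running the device transfer logic on every call now that this forward sits on the CUDA-graph decode path. A minimal standalone sketch of the same lazy-migration pattern, using a hypothetical holder class rather than vLLM's `RotaryEmbedding`:

```python
import torch


class RotaryCacheHolder:
    """Sketch of the guard pattern above (hypothetical class, not vLLM's)."""

    def __init__(self, max_pos: int, dim: int):
        # Cache is built on whatever device the constructor runs on (CPU here).
        self.cos_sin_cache = torch.randn(max_pos, dim)

    def lookup(self, positions: torch.Tensor) -> torch.Tensor:
        # One-time migration instead of an unconditional .to(...) per call.
        if self.cos_sin_cache.device != positions.device:
            self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        return self.cos_sin_cache[positions]


holder = RotaryCacheHolder(max_pos=4096, dim=128)
out = holder.lookup(torch.arange(8))  # first call on a new device pays the copy
```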

vllm/model_executor/models/deepseek_v2.py

Lines changed: 7 additions & 1 deletion
@@ -453,7 +453,6 @@ def __init__(
             qk_rope_head_dim=self.qk_rope_head_dim,
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
-            rotary_emb=self.rotary_emb,
             kv_b_proj=self.kv_b_proj,
         )
 
@@ -475,6 +474,13 @@ def forward(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
 
+        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
+        # Add head dim of 1 to k_pe
+        k_pe = k_pe.unsqueeze(1)
+
+        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
+            positions, q[..., self.qk_nope_head_dim:], k_pe)
+
         attn_out = self.mla_attn(
             q,
             kv_c_normed,
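
With the rotary embedding now applied in the DeepSeek-V2 attention module's forward (above) rather than inside the MLA backend, only the rope slice of `q` is rotated, in place, before `self.mla_attn` is called. A toy sketch of that slice-assignment pattern (shapes and the `apply_rope` placeholder are illustrative assumptions, not vLLM's `RotaryEmbedding` API):

```python
import torch

T, H, D_NOPE, D_ROPE = 8, 16, 128, 64  # tokens, heads, nope/rope head dims


def apply_rope(positions, q_pe, k_pe):
    # Placeholder for RotaryEmbedding.forward(); the real rotation is omitted.
    return q_pe, k_pe


positions = torch.arange(T)
q = torch.randn(T, H, D_NOPE + D_ROPE)
k_pe = torch.randn(T, D_ROPE).unsqueeze(1)  # add a head dim of 1, as in the diff

# Assigning through the slice writes the rotated values back into q, so the
# attention backend receives q with rope already applied to its last D_ROPE dims.
q[..., D_NOPE:], k_pe = apply_rope(positions, q[..., D_NOPE:], k_pe)
```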
