 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
@@ -377,7 +375,6 @@ def graph_capture_get_metadata_for_batch(
             seq_start_loc=None,
             context_lens_tensor=None,
             block_tables=self._graph_block_tables[:batch_size],
-            input_positions=self._positions[:batch_size],
             head_dim=self.runner.model_config.get_head_size())

         if is_encoder_decoder_model:
@@ -393,7 +390,6 @@ def get_graph_input_buffers(self,
             "slot_mapping": attn_metadata.slot_mapping,
             "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
             "block_tables": attn_metadata.decode_metadata.block_tables,
-            "input_positions": attn_metadata.decode_metadata.input_positions,
         }
         if is_encoder_decoder_model:
             raise NotImplementedError(
@@ -405,16 +401,10 @@ def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        input_positions = attn_metadata.input_positions
-        num_positions = input_positions.shape[0]
         input_buffers["seq_lens_tensor"].copy_(
             attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
         input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        # CUDA graph buffer is padded so only perform a partial copy based on
-        # num_positions
-        input_buffers["input_positions"][:num_positions].copy_(
-            input_positions, non_blocking=True)
         if is_encoder_decoder_model:
             raise NotImplementedError(
                 "TritonMLAState does not support encoder/decoder yet")
@@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool

-    # New for MLA (compared to FlashAttention)
-    # Input positions for rotary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
-
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
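For context (not part of this patch): the removed field existed only because rotary position embeddings used to be applied inside the MLA attention backend; once q/k reach the backend already rotated, token positions no longer need to ride along in the attention metadata. A rough, self-contained sketch of caller-side rotation, assuming an interleaved-pair rotary layout (illustrative only, not vLLM's RotaryEmbedding):

import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    # x: (num_tokens, num_heads, rope_dim); positions: (num_tokens,)
    rope_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]  # (T, rope_dim // 2)
    cos = angles.cos()[:, None, :]  # broadcast over the head dimension
    sin = angles.sin()[:, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Toy usage: rotate the rope portions of q and k once, before the attention
# backend is invoked, so the backend itself never needs token positions.
q_pe = torch.randn(4, 2, 64)   # (tokens, heads, qk_rope_head_dim)
k_pe = torch.randn(4, 1, 64)
pos = torch.arange(4)
q_pe, k_pe = apply_rope(q_pe, pos), apply_rope(k_pe, pos)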
@@ -563,8 +548,6 @@ def prefill_metadata(self):
                 self.context_lens_tensor[:self.num_prefills])
             block_tables = (None if self.block_tables is None else
                             self.block_tables[:self.num_prefills])
-            input_positions = (None if self.input_positions is None else
-                               self.input_positions[:self.num_prefill_tokens])

             self._cached_prefill_metadata = self.__class__(
                 # Required by ModelRunner
@@ -578,7 +561,6 @@ def prefill_metadata(self):
                 multi_modal_placeholder_index_maps=None,
                 enable_kv_scales_calculation=False,
                 # MLACommonMetadata
-                input_positions=input_positions,
                 seq_lens=seq_lens,
                 seq_lens_tensor=seq_lens_tensor,
                 max_query_len=self.max_query_len,
@@ -615,8 +597,6 @@ def decode_metadata(self):
                 self.seq_lens_tensor[self.num_prefills:])
             block_tables = (None if self.block_tables is None else
                             self.block_tables[self.num_prefills:])
-            input_positions = (None if self.input_positions is None else
-                               self.input_positions[self.num_prefill_tokens:])

             self._cached_decode_metadata = self.__class__(
                 # Required by ModelRunner
@@ -646,7 +626,6 @@ def decode_metadata(self):
                 if self.seq_start_loc is not None else None,
                 context_lens_tensor=None,
                 block_tables=block_tables,
-                input_positions=input_positions,
                 head_dim=self.head_dim,
                 is_profile_run=self.is_profile_run)
         return self._cached_decode_metadata
@@ -765,7 +744,6 @@ def prepare(self):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.input_positions: List[int] = []
         self.multimodal_placeholder_maps: Dict[
             str,
             MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
@@ -786,13 +764,11 @@ def _add_seq_group(
         block_tables = inter_data.block_tables

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block, input_positions) in zip(
+             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
-                inter_data.curr_sliding_window_blocks,
-                inter_data.input_positions):
-            self.input_positions.extend(input_positions)
+                inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
@@ -912,8 +888,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        input_positions = async_tensor_h2d(self.input_positions, torch.long,
-                                           device, self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
         query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
@@ -987,7 +961,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             multi_modal_placeholder_index_maps=None,  # Not Attention Related
             enable_kv_scales_calculation=False,
             # MLACommonMetadata
-            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
@@ -1033,7 +1006,6 @@ def __init__(
             qk_rope_head_dim: int,
             qk_head_dim: int,
             v_head_dim: int,
-            rotary_emb: RotaryEmbedding,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
         self.num_heads = num_heads
@@ -1048,10 +1020,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-
-        self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
         self.kv_b_proj = kv_b_proj

         self.triton_fa_func = triton_attention
@@ -1367,41 +1335,15 @@ def forward(
         has_decode = attn_metadata.decode_metadata is not None
         has_prefill = attn_metadata.prefill_metadata is not None

-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-        assert hasattr(attn_metadata, "input_positions")
-
         num_prefill_tokens: int = attn_metadata.num_prefill_tokens
         q = q.view(-1, self.num_heads, self.qk_head_dim)

         decode_q = q[num_prefill_tokens:]
-        decode_k_pe = k_pe[num_prefill_tokens:]
-        decode_input_positions = \
-            attn_metadata.input_positions[num_prefill_tokens:]

         prefill_q = q[:num_prefill_tokens]
         prefill_k_pe = k_pe[:num_prefill_tokens]
-        prefill_input_positions = \
-            attn_metadata.input_positions[:num_prefill_tokens]
         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

-        if has_decode:
-            decode_q_nope, decode_q_pe = decode_q.split(
-                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-            # Convert from (B, N, P) to (N, B, P)
-            decode_q_nope = decode_q_nope.transpose(0, 1)
-            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-            # Convert from (N, B, L) to (B, N, L)
-            decode_ql_nope = decode_ql_nope.transpose(0, 1)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                decode_input_positions, decode_q_pe, decode_k_pe)
-
-        if has_prefill:
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                prefill_input_positions, prefill_q_pe, prefill_k_pe)
-
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -1424,6 +1366,15 @@ def forward(
                 attn_metadata)

         if has_decode:
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
             output[num_prefill_tokens:] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
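For context (not part of this patch): the relocated decode block performs the per-head weight-absorption projection, mapping q_nope from (B, N, P) to (B, N, L) with a batched matmul against W_UK_T of shape (N, P, L). A shape-only sketch with toy sizes (all names and dimensions illustrative):

import torch

# Toy sizes: B decode tokens, N heads, P = qk_nope_head_dim, L = kv_lora_rank.
B, N, P, L = 4, 2, 16, 8
decode_q_nope = torch.randn(B, N, P)
W_UK_T = torch.randn(N, P, L)               # one absorbed projection matrix per head

q_nope_t = decode_q_nope.transpose(0, 1)    # (B, N, P) -> (N, B, P)
ql_nope_t = torch.bmm(q_nope_t, W_UK_T)     # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = ql_nope_t.transpose(0, 1)  # (N, B, L) -> (B, N, L), token-major again
assert decode_ql_nope.shape == (B, N, L)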