@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 self.runner.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                 self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
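Note: the first hunk only wraps the NumPy `.max()` calls in `int(...)`. `ndarray.max()` returns a NumPy scalar (e.g. `numpy.int32`), while downstream consumers of `local_max_query_len` / `local_max_seq_len` expect a plain Python `int`. A minimal standalone illustration of the distinction (array values made up, not from the patch):

    import numpy as np

    seqlens_q_local_np = np.array([3, 7, 5], dtype=np.int32)

    max_as_numpy = seqlens_q_local_np.max()     # NumPy scalar
    max_as_int = int(seqlens_q_local_np.max())  # plain Python int

    print(type(max_as_numpy))  # <class 'numpy.int32'>
    print(type(max_as_int))    # <class 'int'>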
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len=local_max_seq_len,
                 causal=True)
 
+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
+
             local_attn_metadata = \
                 AiterFlashAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
                 local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
+                local_cu_seq_lens=local_cu_seq_lens,
                 local_scheduler_metadata=local_scheduler_metadata,
             )
 
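The added `local_cu_seq_lens` is an exclusive prefix sum over the per-virtual-sequence KV lengths, i.e. the `[0, l0, l0+l1, ...]` boundaries that varlen attention kernels expect. A CPU-only sketch of the same construction (lengths are invented, and the device placement from the patch is omitted):

    import numpy as np
    import torch

    # Per-virtual-sequence KV lengths (made-up values).
    virt_k_seqlens_np = np.array([4, 2, 6], dtype=np.int32)

    local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                    dtype=torch.int32)
    local_cu_seq_lens[1:] = torch.cumsum(
        torch.from_numpy(virt_k_seqlens_np).to(dtype=torch.int32), dim=0)

    print(local_cu_seq_lens.tolist())  # [0, 4, 6, 12]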
@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
         local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0.
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,22 +493,25 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
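The `forward` change wraps the existing cache write in a guard: when `kv_sharing_target_layer_name` is set, this layer reads a KV cache that an earlier layer already populated, so repeating `reshape_and_cache_flash` would be wasted work. A simplified, hypothetical stand-in for that gating (not the vLLM API; the scatter below only approximates the fused cache op):

    from typing import Optional

    import torch

    def maybe_store_kv(key: torch.Tensor,
                       value: torch.Tensor,
                       key_cache: torch.Tensor,
                       value_cache: torch.Tensor,
                       slot_mapping: torch.Tensor,
                       kv_sharing_target_layer_name: Optional[str] = None) -> None:
        """Write K/V into the paged cache only if this layer owns the cache."""
        if kv_sharing_target_layer_name is not None:
            # KV cache is shared with an earlier layer that already wrote it.
            return
        # Simplified stand-in for torch.ops._C_cache_ops.reshape_and_cache_flash:
        # scatter each token's K/V into its assigned cache slot.
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value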
@@ -541,7 +559,8 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                              local_metadata.local_cu_seq_lens),
             )
 
         _, num_heads, head_size = query.shape
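The last hunk makes the kernel's `cu_seqlens_k` follow the active attention mode: the global cumulative KV lengths normally, and the new local ones when chunked local attention is in use. A toy example (chunk size and lengths invented, not from the patch) of why the two differ:

    import torch

    # One request with KV length 10 and an assumed attention chunk size of 4:
    # local attention sees three virtual sequences of lengths 4, 4 and 2.
    cu_seq_lens = torch.tensor([0, 10], dtype=torch.int32)   # global boundaries
    virt_k_seqlens = torch.tensor([4, 4, 2], dtype=torch.int32)

    local_cu_seq_lens = torch.zeros(4, dtype=torch.int32)
    local_cu_seq_lens[1:] = torch.cumsum(virt_k_seqlens, dim=0)

    use_local_attn = True
    cu_seqlens_k = cu_seq_lens if not use_local_attn else local_cu_seq_lens
    print(cu_seqlens_k.tolist())  # [0, 4, 8, 10]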