
Commit b515b90

tjtanaa authored and wwl2755-google committed
[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) (vllm-project#19904)
Signed-off-by: tjtanaa <[email protected]>
1 parent 0881567 commit b515b90

2 files changed: 46 additions & 23 deletions

vllm/attention/layer.py

Lines changed: 9 additions & 5 deletions
@@ -306,12 +306,16 @@ def __init__(
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
+        if current_platform.is_rocm():
+            # currently, only torch_sdpa is supported on rocm
+            self.attn_backend = _Backend.TORCH_SDPA
+        else:
+            if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+                backend = _Backend.XFORMERS
 
-        self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
-        } else _Backend.TORCH_SDPA
+            self.attn_backend = backend if backend in {
+                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
+            } else _Backend.TORCH_SDPA
 
     def forward(
         self,
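For context on this hunk: on ROCm, MultiHeadAttention now always falls back to torch SDPA instead of remapping flash-attn to xFormers. Below is a minimal, runnable sketch of that selection logic; the _Backend enum and select_mha_backend here are illustrative stand-ins defined for this example, not vLLM's actual classes.

# Illustrative stand-in for vLLM's _Backend enum and the selection logic above.
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()
    PALLAS_VLLM_V1 = auto()


def select_mha_backend(backend: _Backend, is_rocm: bool) -> _Backend:
    """Mirror the patched __init__ logic for MultiHeadAttention."""
    if is_rocm:
        # Currently only torch_sdpa is supported on ROCm for this layer.
        return _Backend.TORCH_SDPA
    # Elsewhere, flash-attn variants are remapped to xFormers, and anything
    # outside the supported set falls back to torch SDPA.
    if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        backend = _Backend.XFORMERS
    return backend if backend in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
    } else _Backend.TORCH_SDPA


assert select_mha_backend(_Backend.FLASH_ATTN, is_rocm=True) is _Backend.TORCH_SDPA
assert select_mha_backend(_Backend.FLASH_ATTN, is_rocm=False) is _Backend.XFORMERS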

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 37 additions & 18 deletions
@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 self.runner.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                 self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
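A small aside on the int() casts above (an illustration, not part of the commit): numpy's .max() returns a numpy scalar rather than a built-in int, while the metadata fields are typed as plain Python ints.

import numpy as np

seqlens_q_local_np = np.array([3, 7, 5], dtype=np.int32)

raw_max = seqlens_q_local_np.max()                    # numpy scalar (np.int32)
local_max_query_len = int(seqlens_q_local_np.max())   # plain Python int

# type(raw_max) is numpy.int32, which is not a subclass of the built-in int.
assert isinstance(local_max_query_len, int)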
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len=local_max_seq_len,
                 causal=True)
 
+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
+
             local_attn_metadata = \
                 AiterFlashAttentionMetadata.LocalAttentionMetadata(
                     local_query_start_loc=local_query_start_loc,
                     local_seqused_k=local_seqused_k,
                     local_block_table=virt_block_table_tensor,
                     local_max_query_len=local_max_query_len,
                     local_max_seq_len=local_max_seq_len,
+                    local_cu_seq_lens=local_cu_seq_lens,
                     local_scheduler_metadata=local_scheduler_metadata,
                 )
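To make the new local_cu_seq_lens construction concrete, here is a small standalone sketch; the lengths and the CPU device are made up for illustration and stand in for virt_k_seqlens_np and self.runner.device.

import numpy as np
import torch

# Stand-ins for the builder's per-chunk ("virtual") K sequence lengths and device.
virt_k_seqlens_np = np.array([4, 2, 6, 3], dtype=np.int32)
device = torch.device("cpu")

# Zero-prefixed cumulative lengths: varlen attention kernels slice sequence i
# out of the packed K/V using cu_seqlens[i]:cu_seqlens[i + 1].
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                dtype=torch.int32,
                                device=device)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.from_numpy(virt_k_seqlens_np).to(device=device, dtype=torch.int32),
    dim=0)

assert local_cu_seq_lens.tolist() == [0, 4, 6, 12, 15]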

@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
         local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0.
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,22 +493,25 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
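The guard added above writes K/V into the paged cache only when the layer owns that cache; a layer constructed with kv_sharing_target_layer_name reuses the cache already populated by its target layer and must not write again. A simplified, self-contained sketch of the gating follows; maybe_write_kv_cache and the flat cache layout are illustrative stand-ins for torch.ops._C_cache_ops.reshape_and_cache_flash, which requires vLLM's compiled ops.

from typing import Optional

import torch


def maybe_write_kv_cache(
    key: torch.Tensor,           # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,         # [num_tokens, num_kv_heads, head_size]
    key_cache: torch.Tensor,     # [num_slots, num_kv_heads, head_size]
    value_cache: torch.Tensor,   # [num_slots, num_kv_heads, head_size]
    slot_mapping: torch.Tensor,  # [num_tokens], flat cache slot per token
    kv_sharing_target_layer_name: Optional[str],
) -> None:
    """Skip the cache write when this layer shares KV with an earlier layer."""
    if kv_sharing_target_layer_name is not None:
        # The target layer already populated these slots; writing again
        # would be redundant.
        return
    # Simplified stand-in for reshape_and_cache_flash: scatter each token's
    # K/V into its assigned slot.
    key_cache[slot_mapping] = key
    value_cache[slot_mapping] = value

In the real backend the caches are block-paged and the write also applies layer._k_scale / layer._v_scale when an fp8 KV cache dtype is used.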
@@ -541,7 +559,8 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=cu_seq_lens,
+                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                                  local_metadata.local_cu_seq_lens),
                 )
 
             _, num_heads, head_size = query.shape
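And a short illustration (the numbers and attn_chunk_size are invented for this example) of why the local-attention path needs its own cu_seqlens_k: for Llama4-style chunked local attention each chunk acts as a virtual sequence, so the cumulative K lengths differ from the global per-request ones.

import torch

attn_chunk_size = 4
seq_lens = [10, 6]  # two requests

# Each request is split into local-attention chunks of at most attn_chunk_size
# tokens; every chunk becomes a "virtual" sequence for the varlen kernel.
virt_k_seqlens = [
    min(attn_chunk_size, seq_len - start)
    for seq_len in seq_lens
    for start in range(0, seq_len, attn_chunk_size)
]  # -> [4, 4, 2, 4, 2]

local_cu_seq_lens = torch.zeros(len(virt_k_seqlens) + 1, dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.tensor(virt_k_seqlens, dtype=torch.int32), dim=0)

# Global cu_seqlens_k would be [0, 10, 16]; the local path passes the
# per-chunk prefix sums instead.
print(local_cu_seq_lens.tolist())  # [0, 4, 8, 10, 14, 16]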
