34 changes: 22 additions & 12 deletions vllm/v1/attention/backends/flashinfer.py
@@ -353,8 +353,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
        attn_metadata.decode_wrapper = self._get_decode_wrapper()
        if not FlashInferBackend.use_trtllm_decode_attention(
                num_decodes, attn_metadata.max_seq_len,
-               attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-               attn_metadata.num_kv_heads, attn_metadata.head_dim):
+               self.cache_config.cache_dtype,
+               attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+               attn_metadata.head_dim):
            attn_metadata.decode_wrapper.plan(
                attn_metadata.paged_kv_indptr[:num_decodes + 1],
                attn_metadata.paged_kv_indices,
@@ -539,10 +540,10 @@ def forward(
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape -
+               # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
+               # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]

            attn_metadata: Metadata for attention.
        Returns:
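For orientation, here is a minimal sketch (plain PyTorch with illustrative sizes, not vLLM's actual allocation code) of the two cache layouts named in the docstring:

```python
import torch

# Illustrative sizes only; real values come from the model and cache config.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

# NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
nhd = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

# HND: [num_blocks, 2, num_kv_heads, block_size, head_size] -- the same
# storage viewed with the head dimension ahead of the token dimension.
hnd = nhd.permute(0, 1, 3, 2, 4)
assert hnd.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)
```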
@@ -614,6 +615,7 @@ def forward(
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
+       kv_cache_permute = kv_cache.permute(*stride_order)
Contributor (critical):
The `permute` operation on a tensor does not guarantee that the result is contiguous in memory. The `trtllm_batch_decode_with_kv_cache` kernel used later in this function requires a contiguous `kv_cache` tensor, as the new assertion `assert kv_cache_permute.is_contiguous()` on line 667 suggests.

However, the permutation for the `HND` layout will likely produce a non-contiguous tensor, causing the assertion to fail at runtime.

To ensure correctness, make the tensor contiguous immediately after permuting; this satisfies the kernel's requirement and lets the assertion pass:

kv_cache_permute = kv_cache.permute(*stride_order).contiguous()
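
A quick standalone repro of the failure mode described above (plain PyTorch; the stride order is a hypothetical stand-in for what `get_kv_cache_stride_order()` returns):

```python
import torch

kv_cache = torch.zeros(4, 2, 16, 8, 64)     # allocated NHD-contiguous
stride_order = (0, 1, 3, 2, 4)               # hypothetical NHD -> HND order
kv_cache_permute = kv_cache.permute(*stride_order)

print(kv_cache_permute.is_contiguous())      # False: permute only swaps strides
print(kv_cache_permute.contiguous().is_contiguous())  # True: materializes a copy
```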

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
@@ -628,13 +630,13 @@ def forward(
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
-               kv_cache.permute(*stride_order),
+               kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
            )
        if decode_wrapper := attn_metadata.decode_wrapper:
-           decode_query = query[:num_decode_tokens]
+           decode_query = query[:num_decode_tokens].contiguous()
Contributor: move this switch to the `else` branch, because it is only TRT-LLM-specific.
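A sketch of the restructuring this comment suggests (hypothetical helper; names other than `query`, `decode_query`, and `num_decode_tokens` are illustrative):

```python
import torch

def slice_decode_query(query: torch.Tensor, num_decode_tokens: int,
                       use_trtllm: bool) -> torch.Tensor:
    # query[:n] is a view; whether it is contiguous depends on how the
    # parent tensor was produced upstream.
    decode_query = query[:num_decode_tokens]
    if use_trtllm:
        # Only the TRT-LLM decode kernel needs the explicit copy, so pay
        # for .contiguous() on this branch alone.
        decode_query = decode_query.contiguous()
    return decode_query
```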
            assert decode_query.shape[0] == num_decode_tokens
            if not FlashInferBackend.use_trtllm_decode_attention(
                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
@@ -647,27 +649,35 @@ def forward(
                assert decode_wrapper._sm_scale == self.scale
                decode_wrapper.run(
                    decode_query,
-                   kv_cache.permute(*stride_order),
+                   kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
            else:
                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                if num_decode_tokens > 0:
+                   block_tables_decode = attn_metadata.block_table_tensor[:
+                       num_decode_tokens]
+                   seq_lens_decode = attn_metadata.seq_lens[:
+                       num_decode_tokens]
+
+                   assert get_kv_cache_layout() == "HND"
+                   assert decode_query.is_contiguous()
+                   assert kv_cache_permute.is_contiguous()
+                   assert block_tables_decode.is_contiguous()
+                   assert seq_lens_decode.is_contiguous()
+
                    output[:num_decode_tokens] = (
                        trtllm_batch_decode_with_kv_cache(
                            query=decode_query,
-                           kv_cache=kv_cache.permute(*stride_order),
+                           kv_cache=kv_cache_permute,
                            workspace_buffer=attn_metadata.workspace_buffer,
                            num_heads=self.num_heads,
                            num_kv_heads=self.num_kv_heads,
                            scale=self.scale,
-                           block_tables=attn_metadata.
-                           block_table_tensor[:num_decode_tokens],
-                           seq_lens=attn_metadata.
-                           seq_lens[:num_decode_tokens],
+                           block_tables=block_tables_decode,
+                           seq_lens=seq_lens_decode,
                            block_size=attn_metadata.page_size,
                            max_seq_len=attn_metadata.max_seq_len,
                            kv_cache_dtype=self.kv_cache_dtype,
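Per the comment in the diff, this TRT-LLM decode path requires the HND cache layout; one way to select it (assuming `VLLM_KV_CACHE_LAYOUT`, taken from that comment, is read at engine start-up) is:

```python
import os

# Must be set before the vLLM engine is constructed, so the KV cache is
# allocated in HND layout and the TRT-LLM decode path above can be taken.
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"
```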