[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel #21133


Merged · 2 commits · Jul 18, 2025
34 changes: 23 additions & 11 deletions vllm/v1/attention/backends/flashinfer.py
@@ -353,8 +353,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
            attn_metadata.decode_wrapper = self._get_decode_wrapper()
            if not FlashInferBackend.use_trtllm_decode_attention(
                    num_decodes, attn_metadata.max_seq_len,
-                   attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                   attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                   self.cache_config.cache_dtype,
+                   attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                   attn_metadata.head_dim):
                attn_metadata.decode_wrapper.plan(
                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
                    attn_metadata.paged_kv_indices,
@@ -539,10 +540,10 @@ def forward(
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape -
            # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
            # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]

            attn_metadata: Metadata for attention.
        Returns:
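(Aside: the two layouts named in this docstring hold the same data and differ only in the order of the block_size and num_kv_heads axes, so converting between them is a single permute. A minimal PyTorch sketch with made-up sizes:

import torch

# Illustrative sizes only; nothing here comes from the PR.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# NHD storage: [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache_nhd = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# HND is the same data with the block_size and num_kv_heads axes swapped.
kv_cache_hnd = kv_cache_nhd.permute(0, 1, 3, 2, 4)
print(kv_cache_hnd.shape)  # torch.Size([4, 2, 8, 16, 128])
)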
@@ -614,6 +615,7 @@ def forward(
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
+       kv_cache_permute = kv_cache.permute(*stride_order)
Contributor review comment:
critical

The permute operation returns a view and does not guarantee that the resulting tensor is contiguous in memory. The trtllm_batch_decode_with_kv_cache kernel used later in this function requires a contiguous kv_cache tensor, as the new assertion assert kv_cache_permute.is_contiguous() on line 667 suggests. However, permuting to the HND layout will likely produce a non-contiguous tensor, causing that assertion to fail at runtime.

To ensure correctness, make the tensor contiguous immediately after permuting; this satisfies the kernel's requirement and lets the assertion pass:

kv_cache_permute = kv_cache.permute(*stride_order).contiguous()
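The failure mode the reviewer describes is easy to reproduce in isolation. A standalone PyTorch snippet (the shapes and stride order below are illustrative, not taken from vLLM):

import torch

kv_cache = torch.zeros(4, 2, 16, 8, 128)  # toy NHD cache; sizes are made up
stride_order = (0, 1, 3, 2, 4)             # NHD -> HND axis order

permuted = kv_cache.permute(*stride_order)
print(permuted.is_contiguous())            # False: permute only remaps strides

fixed = kv_cache.permute(*stride_order).contiguous()
print(fixed.is_contiguous())               # True: .contiguous() copies into a fresh buffer

Note that the merged diff keeps the plain permute and adds assertions rather than calling .contiguous(); presumably the cache is already stored in the kernel's expected HND layout when this path is taken, making the permute a no-op in practice and avoiding an extra copy per call.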

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
@@ -628,7 +630,7 @@ def forward(
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
-               kv_cache.permute(*stride_order),
+               kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
@@ -647,27 +649,37 @@ def forward(
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
-               kv_cache.permute(*stride_order),
+               kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
            )
        else:
            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            if num_decode_tokens > 0:
+               # decode_query may be non-contiguous
+               decode_query = decode_query.contiguous()
+               block_tables_decode = attn_metadata.block_table_tensor[:
+                   num_decode_tokens]
+               seq_lens_decode = attn_metadata.seq_lens[:
+                   num_decode_tokens]
+
+               assert get_kv_cache_layout() == "HND"
+               assert decode_query.is_contiguous()
+               assert kv_cache_permute.is_contiguous()
+               assert block_tables_decode.is_contiguous()
+               assert seq_lens_decode.is_contiguous()
+
                output[:num_decode_tokens] = (
                    trtllm_batch_decode_with_kv_cache(
                        query=decode_query,
-                       kv_cache=kv_cache.permute(*stride_order),
+                       kv_cache=kv_cache_permute,
                        workspace_buffer=attn_metadata.workspace_buffer,
                        num_heads=self.num_heads,
                        num_kv_heads=self.num_kv_heads,
                        scale=self.scale,
-                       block_tables=attn_metadata.
-                       block_table_tensor[:num_decode_tokens],
-                       seq_lens=attn_metadata.
-                       seq_lens[:num_decode_tokens],
+                       block_tables=block_tables_decode,
+                       seq_lens=seq_lens_decode,
                        block_size=attn_metadata.page_size,
                        max_seq_len=attn_metadata.max_seq_len,
                        kv_cache_dtype=self.kv_cache_dtype,
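Putting the pieces together, here is a self-contained sketch of the tensor preparation the merged decode path performs before handing off to FlashInfer. Every size below is illustrative, stride_order stands in for FlashInferBackend.get_kv_cache_stride_order(), and the actual kernel call is left as a comment:

import torch

# Illustrative sizes; none of these values come from the PR.
num_tokens, num_decode_tokens = 32, 8
num_heads, num_kv_heads, head_size = 32, 8, 128
num_blocks, block_size, max_blocks_per_seq = 64, 16, 4

query = torch.zeros(num_tokens, num_heads, head_size)
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)
block_table_tensor = torch.zeros(num_tokens, max_blocks_per_seq, dtype=torch.int32)
seq_lens = torch.ones(num_tokens, dtype=torch.int32)

# Permute the cache into the kernel's layout once, up front.
stride_order = (0, 1, 3, 2, 4)
# .contiguous() keeps this toy example's asserts passing; the merged code
# instead relies on the cache already being stored in the HND layout.
kv_cache_permute = kv_cache.permute(*stride_order).contiguous()

# Slice out the decode portion and force contiguity, as the fix does.
decode_query = query[:num_decode_tokens].contiguous()
block_tables_decode = block_table_tensor[:num_decode_tokens]
seq_lens_decode = seq_lens[:num_decode_tokens]

# The same guards the merged code asserts before the kernel call.
assert kv_cache_permute.is_contiguous()
assert decode_query.is_contiguous()
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()

# output[:num_decode_tokens] = trtllm_batch_decode_with_kv_cache(
#     query=decode_query, kv_cache=kv_cache_permute, ...)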