Skip to content

[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel #21133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2025

Conversation

elvischenv
Copy link
Contributor

@elvischenv elvischenv commented Jul 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

  • Fix the non-contiguous tensor decode_query used in Flashinfer TRT-LLM attention kernel
  • Fix the function arguments passing of FlashInferBackend.use_trtllm_decode_attention
    • self.cache_config.cache_dtype instead of attn_metadata.kv_data_type

Test Plan

Check the accuracy with lm_eval.

Test Result

Before:

vllm (pretrained=nvidia/Llama-4-Scout-17B-16E-Instruct-FP8,quantization=modelopt,tensor_parallel_size=1,max_model_len=2048,kv_cache_dtype=auto,trust_remote_code=True), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.640|±  |0.0215|
|     |       |strict-match    |     5|exact_match|↑  |0.598|±  |0.0219|

After:

vllm (pretrained=nvidia/Llama-4-Scout-17B-16E-Instruct-FP8,quantization=modelopt,tensor_parallel_size=1,max_model_len=2048,kv_cache_dtype=auto,trust_remote_code=True), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.93|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  | 0.91|±  |0.0128|

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 17, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes two bugs in the FlashInfer attention backend. The first addresses an issue where a non-contiguous decode_query tensor was passed to a kernel, which is resolved by adding a .contiguous() call. The second corrects the data type of an argument passed to use_trtllm_decode_attention. The PR also includes a good refactoring to compute the kv_cache permutation only once.

My review identifies a critical issue with this refactoring. The permuted kv_cache is not guaranteed to be contiguous, which will cause a runtime assertion failure in the TRT-LLM attention path. I've suggested making the tensor contiguous at the point of creation to fix this.

@@ -614,6 +615,7 @@ def forward(
num_prefill_tokens = attn_metadata.num_prefill_tokens

stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The permute operation on a tensor does not guarantee that the resulting tensor is contiguous in memory. The trtllm_batch_decode_with_kv_cache kernel, which is used later in this function, requires a contiguous kv_cache tensor, as suggested by the new assertion assert kv_cache_permute.is_contiguous() on line 667.

However, the permutation for the 'HND' layout will likely produce a non-contiguous tensor, causing the assertion to fail at runtime.

To ensure correctness, you should make the tensor contiguous immediately after permuting. This will satisfy the kernel's requirement and ensure the assertion passes.

kv_cache_permute = kv_cache.permute(*stride_order).contiguous()

@pavanimajety
Copy link
Contributor

cc: @mgoin for review

Thanks for root causing this Elvis, great effort.

Copy link
Contributor

@pavanimajety pavanimajety left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you

k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
if decode_wrapper := attn_metadata.decode_wrapper:
decode_query = query[:num_decode_tokens]
decode_query = query[:num_decode_tokens].contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this switch to else because it is only trtllm specific

@mgoin mgoin requested a review from LucasWilkinson July 17, 2025 18:43
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for this fix!

@mgoin mgoin enabled auto-merge (squash) July 17, 2025 20:28
@mgoin mgoin added the bug Something isn't working label Jul 17, 2025
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 17, 2025
@mgoin mgoin disabled auto-merge July 17, 2025 23:35
@mgoin mgoin enabled auto-merge (squash) July 17, 2025 23:35
@mgoin mgoin merged commit 8dfb45c into vllm-project:main Jul 18, 2025
78 checks passed
@elvischenv elvischenv deleted the fix-trtllm-gen-attention-accuracy branch July 18, 2025 04:16
WorldExplored pushed a commit to nadathurv/vllm that referenced this pull request Jul 19, 2025
WorldExplored pushed a commit to nadathurv/vllm that referenced this pull request Jul 19, 2025
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants