[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel #21133
Conversation
Signed-off-by: elvischenv <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review

This pull request fixes two bugs in the FlashInfer attention backend. The first addresses an issue where a non-contiguous decode_query tensor was passed to a kernel, which is resolved by adding a .contiguous() call. The second corrects the data type of an argument passed to use_trtllm_decode_attention. The PR also includes a good refactoring to compute the kv_cache permutation only once.

My review identifies a critical issue with this refactoring: the permuted kv_cache is not guaranteed to be contiguous, which will cause a runtime assertion failure in the TRT-LLM attention path. I've suggested making the tensor contiguous at the point of creation to fix this.
@@ -614,6 +615,7 @@ def forward(
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
The permute operation on a tensor does not guarantee that the resulting tensor is contiguous in memory. The trtllm_batch_decode_with_kv_cache kernel, which is used later in this function, requires a contiguous kv_cache tensor, as suggested by the new assertion assert kv_cache_permute.is_contiguous() on line 667. However, the permutation for the 'HND' layout will likely produce a non-contiguous tensor, causing the assertion to fail at runtime.

To ensure correctness, you should make the tensor contiguous immediately after permuting. This will satisfy the kernel's requirement and ensure the assertion passes:

kv_cache_permute = kv_cache.permute(*stride_order).contiguous()
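To make the failure mode concrete, here is a minimal PyTorch sketch. The cache shape and the HND stride order below are illustrative assumptions rather than the exact values used by the FlashInfer backend; the point is that permute() only rewrites strides, while .contiguous() materializes a compact copy that a contiguity-asserting kernel can accept.

```python
import torch

# Toy paged KV cache in an NHD-style layout:
# (num_blocks, 2, block_size, num_kv_heads, head_dim) -- sizes are illustrative.
kv_cache = torch.randn(4, 2, 16, 8, 64)

# Hypothetical HND stride order (swap block_size and num_kv_heads), standing in
# for whatever FlashInferBackend.get_kv_cache_stride_order() returns.
stride_order = (0, 1, 3, 2, 4)

# permute() only reorders strides; no data moves, so the view is non-contiguous.
kv_cache_permute = kv_cache.permute(*stride_order)
assert not kv_cache_permute.is_contiguous()

# A kernel that asserts contiguity (as the TRT-LLM decode path does here) would
# fail on the view above; materializing a contiguous copy satisfies the check.
kv_cache_permute = kv_cache.permute(*stride_order).contiguous()
assert kv_cache_permute.is_contiguous()
```

The trade-off of making the tensor contiguous at the point of creation is one extra copy of the permuted cache view, in exchange for satisfying the kernel's contiguity requirement.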
cc: @mgoin for review. Thanks for root-causing this, Elvis; great effort.
LGTM. Thank you
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            out=output[num_decode_tokens:],
        )
    if decode_wrapper := attn_metadata.decode_wrapper:
-       decode_query = query[:num_decode_tokens]
+       decode_query = query[:num_decode_tokens].contiguous()
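For the decode_query change, note that a sliced tensor inherits its parent's strides, so whether query[:num_decode_tokens] is contiguous depends entirely on how query was produced upstream. A minimal sketch, assuming (purely for illustration) that query arrives as a view into a fused QKV buffer:

```python
import torch

num_tokens, num_heads, head_dim = 6, 4, 8

# Hypothetical fused QKV projection output; slicing out the Q columns yields a
# view whose row stride (3 * num_heads * head_dim) exceeds its row length,
# i.e. a non-contiguous tensor.
qkv = torch.randn(num_tokens, 3 * num_heads * head_dim)
query = qkv[:, : num_heads * head_dim]
assert not query.is_contiguous()

num_decode_tokens = 4
decode_query = query[:num_decode_tokens]  # slicing keeps the parent's strides
assert not decode_query.is_contiguous()

# The fix in this PR: materialize a compact copy that the TRT-LLM kernel accepts.
decode_query = query[:num_decode_tokens].contiguous()
assert decode_query.is_contiguous()
```

When query is already contiguous, .contiguous() returns the same tensor without copying, so the call is cheap insurance on that path.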
Move this to the else branch, since it is only TRT-LLM specific.
Signed-off-by: elvischenv <[email protected]>
LGTM, thanks for this fix!
Essential Elements of an Effective PR Description Checklist
(Optional) Documentation update, such as updating supported_models.md and examples for a new model.

Purpose
Fix the non-contiguous decode_query used in the Flashinfer TRT-LLM attention kernel.
Fix FlashInferBackend.use_trtllm_decode_attention so that it takes self.cache_config.cache_dtype instead of attn_metadata.kv_data_type.
Test Plan
Check the accuracy with lm_eval.
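A hypothetical way to run that accuracy check through lm_eval's Python API; the model checkpoint and task below are placeholders, since the PR description does not record the exact configuration used:

```python
import lm_eval

# Placeholder model and task -- substitute whatever the accuracy check targets.
# Requires lm-eval with vLLM support installed.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct,tensor_parallel_size=1",
    tasks=["gsm8k"],
)
print(results["results"]["gsm8k"])
```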
Test Result
Before:
After:
(Optional) Documentation Update