[ROCm][AITER] Enable fp8 kv cache on rocm aiter backend. #20295
Conversation
Summary of Changes
Hello @fsx950223, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on integrating the Aiter backend for ROCm, enabling FP8 KV cache support, and optimizing attention mechanisms within vLLM. It transitions existing paged attention calls to the new Aiter implementation, introduces explicit handling for FP8 KV cache scaling, and refines supported hardware configurations for ROCm.
Highlights
- Aiter Backend Integration: Switched existing paged attention calls in `rocm_flash_attn.py` and `chunked_prefill_paged_decode.py` to use `torch.ops.aiter.paged_attention_rocm` and `aiter.paged_attention_v1`, indicating a full transition to the Aiter implementation for ROCm.
- FP8 KV Cache Support: Enabled FP8 KV cache by removing the `NotImplementedError` check, adding `k_scale` and `v_scale` parameters to the KV cache layout transformation, and passing these scales to the Aiter paged attention kernels, allowing for efficient FP8 operations (see the conceptual sketch after this list).
- ROCm Compatibility Expansion: Broadened the supported head sizes for ROCm custom paged attention to include 256 and increased the maximum Grouped Query Attention (GQA) ratio to 32, enhancing compatibility with a wider range of models and configurations.
- Optimized Workspace Management: Introduced a dedicated `workspace_buffer` within the `AiterFlashAttentionBackend`'s scheduling mechanism, pre-allocating necessary memory for attention operations and passing it to the Aiter kernels for improved performance and memory efficiency.
- Code Refinements: Streamlined parameter passing for attention functions (e.g., removing `query_start_loc`, renaming `seq_lens` to `context_lens`) and removed redundant query quantization logic, leading to cleaner and more efficient code.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
Feature | Command | Description
---|---|---
Code Review | `/gemini review` | Performs a code review for the current pull request in its current state.
Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state.
Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | `/gemini help` | Displays a list of available commands.
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request enables fp8 KV cache on the ROCm AITER backend. It integrates the `aiter` library's custom ops, updates function signatures, and adds logic to handle fp8 dequantization. The PR expands support for larger head sizes and GQA ratios. A typo and a reduction in supported head sizes were identified.
    def get_supported_head_sizes() -> list[int]:
-       return [32, 64, 96, 128, 160, 192, 224, 256]
+       return [64, 128, 256]
The list of supported head sizes has been significantly reduced. The previous implementation supported a wider range of sizes (`[32, 64, 96, 128, 160, 192, 224, 256]`), but the new implementation only supports `[64, 128, 256]`. This could be a breaking change for models that use other head sizes (e.g., 32, 96).
Please confirm if this is an intended limitation of the new AITer backend and document it if so. If other head sizes can be supported, they should be added back to maintain compatibility.
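If the narrower list is intentional, a guard along the following lines would at least fail loudly for unsupported models. This is a sketch only: the class and method names come from this PR, but the import path and error text are assumptions.

```python
# Hypothetical guard; vLLM's real backend-selection logic may differ.
from vllm.v1.attention.backends.rocm_aiter_fa import (  # import path is an assumption
    AiterFlashAttentionBackend)


def check_head_size(head_size: int) -> None:
    supported = AiterFlashAttentionBackend.get_supported_head_sizes()
    if head_size not in supported:
        raise ValueError(
            f"head_size={head_size} is not supported by the AITER FA backend "
            f"(supported: {supported}); use another ROCm attention backend.")
```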
nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                      1) // _PARTITION_SIZE_ROCM

workspace_buffer = torch.empty(
    (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
    nbyes_per_qo_elem + 2 *
    (num_reqs * self.num_heads_q * max_num_partitions) * 4,
    dtype=torch.uint8,
    device=self.runner.device,
)
There's a typo in the variable name `nbyes_per_qo_elem`. It should be `nbytes_per_qo_elem`. This will improve code readability and maintainability.
Suggested change:
-nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
+nbytes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
 max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                       1) // _PARTITION_SIZE_ROCM
 workspace_buffer = torch.empty(
     (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
-    nbyes_per_qo_elem + 2 *
+    nbytes_per_qo_elem + 2 *
     (num_reqs * self.num_heads_q * max_num_partitions) * 4,
     dtype=torch.uint8,
     device=self.runner.device,
 )
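For readers trying to parse the size expression, here is the same arithmetic with example numbers. The interpretation of the two terms (per-partition partial outputs plus two float32 reduction buffers) is an inference from the usual split-KV pattern rather than something stated in the PR, and the partition size used below is an assumption.

```python
# Example numbers only; the meaning of the two terms is an assumption.
num_reqs, num_heads_q, head_dim = 8, 32, 128
nbytes_per_qo_elem = 2                       # fp16/bf16 query-output dtype
_PARTITION_SIZE_ROCM = 256                   # assumed partition size
max_seq_len = 4096
max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM

# Term 1: one [head_dim] partial output per (request, query head, partition).
partial_out_bytes = (num_reqs * num_heads_q * max_num_partitions * head_dim
                     * nbytes_per_qo_elem)
# Term 2: two float32 scalars (running exp-sum and max-logit) per
# (request, query head, partition).
reduction_bytes = 2 * (num_reqs * num_heads_q * max_num_partitions) * 4
workspace_bytes = partial_out_bytes + reduction_bytes
print(workspace_bytes)  # ~1.03 MiB for these example numbers
```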
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
@@ -559,28 +591,14 @@ def forward(
             alibi_slopes=self.alibi_slopes,
             window_size=self.sliding_window,
             block_table=block_table,
-            cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                          local_metadata.local_cu_seq_lens),
+            cu_seqlens_k=cu_seq_lens,
@fsx950223 can you double check if there is an accidental removal of a bug fix for local attention cases here? It should be using `local_metadata.local_cu_seq_lens` when local attention is enabled.
There is an alignment issue when processing an input sequence larger than 8192.
The bug fix PR for local attention was introduced here: #19904
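For clarity, the behavior being asked to restore amounts to a selection like the one below (the names are taken from the quoted diff; this is a sketch rather than the exact restored code):

```python
# When local (chunked) attention is active, the kernel must see the local
# cumulative sequence lengths rather than the global ones.
cu_seqlens_k = (local_metadata.local_cu_seq_lens
                if use_local_attn else cu_seq_lens)
```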
Done
@@ -305,7 +304,7 @@ def chunked_prefill_paged_decode(
     )
     max_logits = torch.empty_like(exp_sums)

-    ops.paged_attention_rocm(
+    torch.ops.aiter.paged_attention_rocm(
@fsx950223 I think we should avoid replacing this directly.
vLLM's in-repo `ops.paged_attention_rocm` is more generic and covers more GPU architectures, e.g. Radeon GPUs.
With this replacement, it will affect users who are using Radeon GPUs; they cannot build the AITER repo on Radeon.
There should be a switch: for Radeon GPUs, use the original one. I will change it later.
Great. Hope we can land this PR smoothly and benefit from the AITER speed boost. 🙌 🚀
Not just Radeon. Aiter is not a hard requirement for running vLLM, so any AITER API should be opt-in, not opt-out.
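A minimal sketch of the opt-in dispatch being asked for here, assuming the existing `VLLM_ROCM_USE_AITER` environment flag; `on_gfx9()` is a hypothetical helper, and the exact plumbing in vLLM will differ:

```python
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops  # generic vLLM kernels (also cover Radeon)


def on_gfx9() -> bool:
    # Hypothetical arch check; vLLM has its own platform helpers for this.
    return torch.cuda.get_device_properties(0).gcnArchName.startswith("gfx9")


def paged_attention_dispatch(*args, **kwargs):
    # Opt in to the AITER kernel only when explicitly requested and on a
    # supported Instinct (gfx9xx) GPU; otherwise fall back to the generic
    # ROCm kernel, e.g. on Radeon GPUs where AITER cannot be built.
    if envs.VLLM_ROCM_USE_AITER and on_gfx9():
        torch.ops.aiter.paged_attention_rocm(*args, **kwargs)
    else:
        ops.paged_attention_rocm(*args, **kwargs)
```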
vllm/platforms/rocm.py (Outdated)

@@ -138,9 +138,9 @@ def use_rocm_custom_paged_attention(
     return ((not envs.VLLM_USE_V1 or sliding_window == 0
              or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
-            and (head_size == 64 or head_size == 128)
+            and (head_size in [64, 128, 256])
Based on this comment https://github.com/vllm-project/vllm/pull/20295/files#r2179115840, since we need to retain the use of `ops.paged_attention_rocm`, we should create another if/else branch just to handle the case for the AITER Flash Attention.
The condition is only for gfx9xx.
Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)? Upstream, gfx9xx also includes MI200-series GPU support.
Moreover, `chunked_prefill_paged_decode` attention relies on this condition as well (it uses `ops.paged_attention_rocm`). `ops.paged_attention_rocm` depends on a different set of conditions than `torch.ops.aiter.paged_attention_rocm`.

use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,

So we still need to create a separate condition for the AITER FA.
> Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)? Upstream, gfx9xx also includes MI200-series GPU support.

Yes
> Moreover, `chunked_prefill_paged_decode` attention relies on this condition as well (it uses `ops.paged_attention_rocm`). It depends on a different set of conditions than `torch.ops.aiter.paged_attention_rocm`.
>
> use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
>
> So we still need to create a separate condition for the AITER FA.

The Triton backend has its own condition.
Make a new limit in the file
The limit also needs to include the `gqa_ratio` limits?

Latest changes in the code:
-            (gqa_ratio >= 1 and gqa_ratio <= 16)
+            (gqa_ratio >= 1 and gqa_ratio <= 32)
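Putting the quoted fragments together, the AITER-specific limit under discussion would look roughly like the following; the function name is hypothetical, and the real check also gates on other parameters (block size, maximum sequence length, etc.) that are not quoted in this thread:

```python
import torch
import vllm.envs as envs


def _use_aiter_paged_attention(qtype: torch.dtype, head_size: int,
                               gqa_ratio: int, sliding_window) -> bool:
    # Only the conditions quoted in this review thread are shown here.
    return ((not envs.VLLM_USE_V1 or sliding_window == 0
             or sliding_window == (-1, -1))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128, 256)
            and 1 <= gqa_ratio <= 32)
```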
Done
@@ -913,8 +913,7 @@ def forward(
     )
     max_logits = torch.empty_like(exp_sums)

-    query_start_loc = None
-    ops.paged_attention_rocm(
+    torch.ops.aiter.paged_attention_rocm(
Same situation as this comment: https://github.com/vllm-project/vllm/pull/20295/files#r2179115840.
Only gfx9xx GPUs support ROCm flash attention.
We have to change this for v0 workloads.
Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)?
Yes
Only keep v1 changes in the PR.
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks @fsx950223. Can you add a test plan in your PR's description section?
Enable full cuda graph and fp8 kv cache Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: fsx950223 <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: amd-ruitang3 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Signed-off-by: fsx950223 <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: fsx950223 <[email protected]>
Does this feature work with the current aiter commit that vLLM is using, or do we have to upgrade aiter? I saw there are a few more new commits on the FA and PA features after the current aiter commit.
You could try it.
This generally looks reasonable to me. Thanks for the contribution!
…t#20295) Signed-off-by: fsx950223 <[email protected]> Signed-off-by: amd-ruitang3 <[email protected]> Co-authored-by: amd-ruitang3 <[email protected]>
…t#20295) Signed-off-by: fsx950223 <[email protected]> Signed-off-by: amd-ruitang3 <[email protected]> Co-authored-by: amd-ruitang3 <[email protected]> Signed-off-by: shuw <[email protected]>
…t#20295) Signed-off-by: fsx950223 <[email protected]> Signed-off-by: amd-ruitang3 <[email protected]> Co-authored-by: amd-ruitang3 <[email protected]> Signed-off-by: x22x22 <[email protected]>
…t#20295) Signed-off-by: fsx950223 <[email protected]> Signed-off-by: amd-ruitang3 <[email protected]> Co-authored-by: amd-ruitang3 <[email protected]> Signed-off-by: Jinzhen Lin <[email protected]>
The ROCm AITER backend can support fp8 KV cache with the latest aiter.
CMD:
HIP_VISIBLE_DEVICES=3,4 VLLM_ROCM_USE_AITER=1 VLLM_USE_V1=1 vllm serve /models/models--amd--Meta-Llama-3.1-8B-Instruct-FP8-KV/snapshots/fa42f9a9105c545755fea25cf69f49ac8c8b40e1/ --tensor-parallel-size 2 --gpu-memory-utilization 0.9 --trust-remote-code --disable-log-requests --block-size 16 --max-model-len 32768 --dtype float16 --quantization fp8 --no-enable-prefix-caching --max-num-batched-tokens=8192 --kv-cache-dtype fp8 --compilation-config '{"full_cuda_graph":true}'
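A possible smoke test against the server started by the command above; it assumes the default OpenAI-compatible endpoint on port 8000 and reuses the model path from the serve command:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="/models/models--amd--Meta-Llama-3.1-8B-Instruct-FP8-KV/snapshots/"
          "fa42f9a9105c545755fea25cf69f49ac8c8b40e1/",
    prompt="The capital of France is",
    max_tokens=64,
    temperature=0.0,
)
print(resp.choices[0].text)
```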
Essential Elements of an Effective PR Description Checklist
- (Optional) Documentation update if needed, such as updating `supported_models.md` and `examples` for a new model.

Purpose
Test Plan
Test Result
(Optional) Documentation Update