
[ROCm][AITER] Enable fp8 kv cache on rocm aiter backend. #20295


Merged: 6 commits merged into vllm-project:main on Jul 25, 2025

Conversation

@fsx950223 (Contributor) commented on Jul 1, 2025

The ROCm AITER backend can support fp8 KV cache with the latest AITER.
CMD:
HIP_VISIBLE_DEVICES=3,4 VLLM_ROCM_USE_AITER=1 VLLM_USE_V1=1 vllm serve /models/models--amd--Meta-Llama-3.1-8B-Instruct-FP8-KV/snapshots/fa42f9a9105c545755fea25cf69f49ac8c8b40e1/ --tensor-parallel-size 2 --gpu-memory-utilization 0.9 --trust-remote-code --disable-log-requests --block-size 16 --max-model-len 32768 --dtype float16 --quantization fp8 --no-enable-prefix-caching --max-num-batched-tokens=8192 --kv-cache-dtype fp8 --compilation-config '{"full_cuda_graph":true}'

Tasks  Version  Filter            n-shot  Metric       Value   Stderr
gsm8k  3        flexible-extract  5       exact_match  0.796   ± 0.0255
                strict-match      5       exact_match  0.744   ± 0.0277
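The PR does not state the exact evaluation command; one plausible lm-evaluation-harness invocation against the server started by the CMD above would look roughly like the following (the endpoint URL and concurrency are assumptions, not taken from the PR):

lm_eval --model local-completions \
  --model_args model=/models/models--amd--Meta-Llama-3.1-8B-Instruct-FP8-KV/snapshots/fa42f9a9105c545755fea25cf69f49ac8c8b40e1/,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=8 \
  --tasks gsm8k \
  --num_fewshot 5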

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Test Plan

Test Result

(Optional) Documentation Update

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @fsx950223, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on integrating the Aiter backend for ROCm, enabling FP8 KV cache support, and optimizing attention mechanisms within vLLM. It transitions existing paged attention calls to the new Aiter implementation, introduces explicit handling for FP8 KV cache scaling, and refines supported hardware configurations for ROCm.

Highlights

  • Aiter Backend Integration: Switched existing paged attention calls in rocm_flash_attn.py and chunked_prefill_paged_decode.py to use torch.ops.aiter.paged_attention_rocm and aiter.paged_attention_v1, indicating a full transition to the Aiter implementation for ROCm.
  • FP8 KV Cache Support: Enabled FP8 KV cache by removing the NotImplementedError check, adding k_scale and v_scale parameters to the KV cache layout transformation, and passing these scales to the Aiter paged attention kernels, allowing for efficient FP8 operations (see the sketch after this list).
  • ROCm Compatibility Expansion: Broadened the supported head sizes for ROCm custom paged attention to include 256 and increased the maximum Grouped Query Attention (GQA) ratio to 32, enhancing compatibility with a wider range of models and configurations.
  • Optimized Workspace Management: Introduced a dedicated workspace_buffer within the AiterFlashAttentionBackend's scheduling mechanism, pre-allocating necessary memory for attention operations and passing it to the Aiter kernels for improved performance and memory efficiency.
  • Code Refinements: Streamlined parameter passing for attention functions (e.g., removing query_start_loc, renaming seq_lens to context_lens) and removed redundant query quantization logic, leading to cleaner and more efficient code.
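To make the FP8 KV cache highlight concrete, here is a minimal, self-contained sketch of per-tensor FP8 quantization of K/V with explicit scales, in the spirit of the k_scale/v_scale values this PR threads through to the aiter kernels. The function names and the e4m3fnuz choice are illustrative assumptions; this is not the vLLM or AITER API.

import torch

def quantize_kv_fp8(kv: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    # Per-tensor scale: map the largest magnitude onto the FP8 dtype's max value.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = kv.float().abs().amax().clamp(min=1e-12) / fp8_max
    kv_fp8 = (kv.float() / scale).to(fp8_dtype)
    return kv_fp8, scale  # scale plays the role of k_scale / v_scale

def dequantize_kv_fp8(kv_fp8: torch.Tensor, scale: torch.Tensor,
                      out_dtype=torch.float16):
    # A kernel that understands FP8 consumes kv_fp8 plus the scale directly;
    # a fallback path simply dequantizes before the usual attention math.
    return (kv_fp8.to(torch.float32) * scale).to(out_dtype)

k = torch.randn(16, 8, 128, dtype=torch.float16)   # [blocks, heads, head_dim]
k_fp8, k_scale = quantize_kv_fp8(k)
k_ref = dequantize_kv_fp8(k_fp8, k_scale)
print(k_scale.item(), (k - k_ref).abs().max().item())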

@mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Jul 1, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request enables fp8 KV cache on the ROCm AITER backend. It integrates the aiter library's custom ops, updates function signatures, and adds logic to handle fp8 dequantization. The PR expands support for larger head sizes and GQA ratios. A typo and a reduction in supported head sizes were identified.

Comment on lines 358 to 383

 def get_supported_head_sizes() -> list[int]:
-    return [32, 64, 96, 128, 160, 192, 224, 256]
+    return [64, 128, 256]

Contributor:

Severity: high

The list of supported head sizes has been significantly reduced. The previous implementation supported a wider range of sizes ([32, 64, 96, 128, 160, 192, 224, 256]), but the new implementation only supports [64, 128, 256]. This could be a breaking change for models that use other head sizes (e.g., 32, 96).

Please confirm if this is an intended limitation of the new AITer backend and document it if so. If other head sizes can be supported, they should be added back to maintain compatibility.

Comment on lines 312 to 322

nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                      1) // _PARTITION_SIZE_ROCM

workspace_buffer = torch.empty(
    (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
    nbyes_per_qo_elem + 2 *
    (num_reqs * self.num_heads_q * max_num_partitions) * 4,
    dtype=torch.uint8,
    device=self.runner.device,
)
Contributor:

Severity: medium

There's a typo in the variable name nbyes_per_qo_elem. It should be nbytes_per_qo_elem. This will improve code readability and maintainability.

Suggested change

-nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
+nbytes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
 max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                       1) // _PARTITION_SIZE_ROCM
 workspace_buffer = torch.empty(
     (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
-    nbyes_per_qo_elem + 2 *
+    nbytes_per_qo_elem + 2 *
     (num_reqs * self.num_heads_q * max_num_partitions) * 4,
     dtype=torch.uint8,
     device=self.runner.device,
 )
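For readers following the sizing arithmetic above, the buffer packs one partial-output tensor in the query/output dtype plus two float32 reduction tensors (exp_sums and max_logits). A small helper reproducing the same byte count follows; the _PARTITION_SIZE_ROCM value of 256 is an assumption for this sketch (the real constant lives in vLLM and may differ).

import torch

_PARTITION_SIZE_ROCM = 256  # assumed partition size for this sketch

def aiter_workspace_bytes(num_reqs: int, num_heads_q: int, headdim: int,
                          max_seq_len: int, qo_dtype=torch.float16) -> int:
    nbytes_per_qo_elem = torch.finfo(qo_dtype).bits // 8
    max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
    # Partial outputs: [num_reqs, num_heads_q, partitions, headdim] in the q/o dtype.
    partial_out = num_reqs * num_heads_q * max_num_partitions * headdim * nbytes_per_qo_elem
    # Two float32 buffers (exp_sums, max_logits): [num_reqs, num_heads_q, partitions].
    reductions = 2 * (num_reqs * num_heads_q * max_num_partitions) * 4
    return partial_out + reductions

# Example: 8 requests, 32 query heads, head dim 128, 32k max sequence length, fp16 outputs.
print(aiter_workspace_bytes(8, 32, 128, 32768) / 2**20, "MiB")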

@fsx950223 force-pushed the character_ai_upstream branch from 6050a47 to b130f85 on July 1, 2025 03:51
@fsx950223 marked this pull request as draft on July 1, 2025 03:56

github-actions bot commented Jul 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@@ -559,28 +591,14 @@ def forward(
     alibi_slopes=self.alibi_slopes,
     window_size=self.sliding_window,
     block_table=block_table,
-    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                  local_metadata.local_cu_seq_lens),
+    cu_seqlens_k=cu_seq_lens,
Contributor:

@fsx950223 can you double-check whether a bug fix for local attention cases was accidentally removed here?

It should use local_metadata.local_cu_seq_lens when local attention is true.

There is an alignment issue when processing an input sequence larger than 8192.

The bug fix PR for local attention was introduced in #19904.
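For clarity, a tiny stand-alone sketch of the argument selection that #19904 introduced and that was restored per the "Done" reply below; use_local_attn and local_metadata are stubbed here purely for illustration and are not the actual vLLM metadata objects.

from types import SimpleNamespace
import torch

cu_seq_lens = torch.tensor([0, 5, 9], dtype=torch.int32)
local_metadata = SimpleNamespace(
    local_cu_seq_lens=torch.tensor([0, 3, 5, 7, 9], dtype=torch.int32))
use_local_attn = True

# Local attention must use the local cumulative sequence lengths, otherwise
# inputs longer than 8192 tokens misalign.
cu_seqlens_k = (cu_seq_lens if not use_local_attn
                else local_metadata.local_cu_seq_lens)
print(cu_seqlens_k)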

Contributor Author:

Done

@@ -305,7 +304,7 @@ def chunked_prefill_paged_decode(
     )
     max_logits = torch.empty_like(exp_sums)

-    ops.paged_attention_rocm(
+    torch.ops.aiter.paged_attention_rocm(
Contributor:

@fsx950223 I think we should avoid replacing this directly.

vLLM's own ops.paged_attention_rocm is more generic and covers more GPU architectures, e.g. Radeon GPUs. This replacement would affect users on Radeon GPUs, who cannot build the AITER repo on Radeon.

CC @gshtras @hongxiayang @hyoon1

@fsx950223 (Contributor Author), Jul 2, 2025:

There should be a switch: for Radeon GPUs, use the original one. Will change later.
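A minimal sketch of the kind of switch being discussed, where the AITER kernel is strictly opt-in and the generic vLLM kernel remains the default; the predicate name and how it is computed are assumptions for illustration, and both branches require the corresponding ROCm builds to be installed.

import torch

def paged_attention_dispatch(use_aiter: bool, *args, **kwargs):
    if use_aiter:
        # Opt-in fast path: AITER is installed and the GPU is a supported gfx9xx part.
        return torch.ops.aiter.paged_attention_rocm(*args, **kwargs)
    # Default path: vLLM's own kernel, which also covers Radeon GPUs,
    # where the AITER repo cannot be built.
    from vllm import _custom_ops as ops
    return ops.paged_attention_rocm(*args, **kwargs)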

Contributor:

Great. Hoping for a smooth landing of this PR so we can benefit from the AITER speed boost. 🙌 🚀

Collaborator:

Not just Radeon. Aiter is not a hard requirement for running vLLM, so any such API should be opt-in, not opt-out.

@@ -138,9 +138,9 @@ def use_rocm_custom_paged_attention(
     return ((not envs.VLLM_USE_V1 or sliding_window == 0
              or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
-            and (head_size == 64 or head_size == 128)
+            and (head_size in [64, 128, 256])
Contributor:

@fsx950223

Based on this comment https://github.com/vllm-project/vllm/pull/20295/files#r2179115840, since we need to retain the use of ops.paged_attention_rocm, we should create another if/else branch just to handle the AITER Flash Attention case.

Contributor Author:

This condition is only for gfx9xx.

@tjtanaa (Contributor), Jul 2, 2025:

Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)? Upstream, gfx9xx also includes MI200-series GPU support.

@tjtanaa (Contributor), Jul 2, 2025:

Moreover, the chunked_prefill_paged_decode attention relies on this condition as well (it uses ops.paged_attention_rocm), and ops.paged_attention_rocm depends on a different set of conditions than torch.ops.aiter.paged_attention_rocm.

use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,

So we still need to create a separate condition for the AITER FA.

Contributor Author:

Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)? Upstream, gfx9xx also includes MI200-series GPU support.

Yes

Contributor Author:

Moreover, the chunked_prefill_paged_decode attention relies on this condition as well (it uses ops.paged_attention_rocm), and it depends on a different set of conditions than torch.ops.aiter.paged_attention_rocm.

use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,

So we still need to create a separate condition for the AITER FA.

The Triton backend has its own condition.

@fsx950223 (Contributor Author), Jul 2, 2025:

I'll add a new limit in this file.

@tjtanaa (Contributor), Jul 2, 2025:

Does the limit also need to include the gqa_ratio limits? Latest changes in the code:

- (gqa_ratio >= 1 and gqa_ratio <= 16)
+ (gqa_ratio >= 1 and gqa_ratio <= 32)
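Putting the head_size and gqa_ratio limits together, a separate predicate dedicated to the AITER path could look roughly like the sketch below; the function name and the exact set of gates are assumptions (the real check also considers things such as sliding window and the V1 engine flag).

import torch

def use_aiter_paged_attention(qtype: torch.dtype, head_size: int,
                              gqa_ratio: int, kv_cache_dtype: str) -> bool:
    # Kept separate from use_rocm_custom_paged_attention so the limits of the
    # generic ops.paged_attention_rocm kernel stay untouched.
    return (qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128, 256)
            and 1 <= gqa_ratio <= 32
            and kv_cache_dtype in ("auto", "fp8", "fp8_e4m3"))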

Contributor Author:

Done

@@ -913,8 +913,7 @@ def forward(
     )
     max_logits = torch.empty_like(exp_sums)

-    query_start_loc = None
-    ops.paged_attention_rocm(
+    torch.ops.aiter.paged_attention_rocm(

Contributor Author:

Only gfx9xx GPUs support ROCm flash attention.

Contributor Author:

We have to change this for v0 workloads.

Contributor:

Is ROCm/aiter compatible with MI200-series GPUs (e.g. gfx90a)?

Contributor Author:

Yes

@fsx950223 marked this pull request as ready for review on July 3, 2025 11:15
@fsx950223 force-pushed the character_ai_upstream branch from ce69e32 to 70d687f on July 4, 2025 07:09
@fsx950223 (Contributor Author):

Only the v1 changes are kept in this PR.


mergify bot commented Jul 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fsx950223.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@fsx950223 force-pushed the character_ai_upstream branch from cfef34c to 764481f on July 10, 2025 07:01
@fsx950223 requested a review from tlrmchlsmth as a code owner on July 11, 2025 05:34
@fsx950223 force-pushed the character_ai_upstream branch 2 times, most recently from fc2b6ce to b0955f8 on July 11, 2025 05:58
@hongxiayang (Collaborator):

Thanks @fsx950223. Can you add a test plan to your PR's description section?

@fsx950223 force-pushed the character_ai_upstream branch from 4360117 to a2acfb8 on July 15, 2025 03:51
@mergify bot added the performance (Performance-related issues) and speculative-decoding labels and removed the needs-rebase label on Jul 17, 2025
@fsx950223 force-pushed the character_ai_upstream branch from 9c40a55 to 8f050d8 on July 18, 2025 02:48
Enable full cuda graph and fp8 kv cache

Signed-off-by: fsx950223 <[email protected]>
@fsx950223 force-pushed the character_ai_upstream branch from 8f050d8 to 8197296 on July 18, 2025 02:50
Signed-off-by: fsx950223 <[email protected]>

mergify bot commented Jul 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fsx950223.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Jul 18, 2025
@mergify bot removed the needs-rebase label on Jul 18, 2025

mergify bot commented Jul 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fsx950223.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Jul 21, 2025
@mergify bot removed the needs-rebase label on Jul 22, 2025
@gshtras added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jul 23, 2025
@gshtras changed the title from "Enable fp8 kv cache on rocm aiter backend." to "[ROCm][AITER] Enable fp8 kv cache on rocm aiter backend." on Jul 23, 2025
@tjtanaa (Contributor) commented Jul 23, 2025:

Does this feature work with the aiter commit that vLLM is currently pinned to, or do we have to upgrade aiter? I saw a few new commits touching the FA and PA features after the current aiter commit.

@fsx950223 (Contributor Author):

Does this feature work with the aiter commit that vLLM is currently pinned to, or do we have to upgrade aiter? I saw a few new commits touching the FA and PA features after the current aiter commit.

You could try it.

@SageMoore (Contributor) left a comment:

This generally looks reasonable to me. Thanks for the contribution!

@vllm-bot merged commit b3caeb8 into vllm-project:main on Jul 25, 2025
76 of 79 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
…t#20295)

Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Co-authored-by: amd-ruitang3 <[email protected]>
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
…t#20295)

Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Co-authored-by: amd-ruitang3 <[email protected]>
Signed-off-by: shuw <[email protected]>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…t#20295)

Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Co-authored-by: amd-ruitang3 <[email protected]>
Signed-off-by: x22x22 <[email protected]>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
…t#20295)

Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Co-authored-by: amd-ruitang3 <[email protected]>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…t#20295)

Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: amd-ruitang3 <[email protected]>
Co-authored-by: amd-ruitang3 <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
Labels: ci/build, documentation (Improvements or additions to documentation), llama (Related to Llama models), performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, v1
Projects: None yet
7 participants