[Bugfix][ROCm] Fix incorrect casting in GPTQ GEMM kernel #17583

Open · wants to merge 2 commits into main

Conversation

nlzy
Contributor

@nlzy nlzy commented May 2, 2025

As mentioned in #7374, using a GPTQ model with desc_act=True together with tensor parallelism causes the output to become garbled. This issue is specific to ROCm and does not occur on NVIDIA GPUs.

To summarize, this bug can only be triggered if all three of the following conditions are met:

  • Tensor parallelism
  • GPTQ models with desc_act=True
  • ROCm platform

The following code shows that when the model uses desc_act=True and tensor parallelism is enabled, a non-exllama kernel is used for the row-parallel layer:

if (input_size != input_size_per_partition
        and self.quant_config.group_size != -1):
    # For act-order models, we cannot use Exllama for row parallel layer
    if self.quant_config.desc_act:
        exllama_state = ExllamaState.UNUSED

The non-exllama kernel contains the following code:

#ifndef USE_ROCM
  res2 = {};
#else
  res2.x = __half_as_ushort(__float2half(0));
  res2.y = __half_as_ushort(__float2half(0));
#endif

and

#ifndef USE_ROCM
  res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
  res[m] = __hadd(
      res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif

The above code behaves differently on CUDA and ROCm platforms. On CUDA, the logic is straightforward: variables are simply initialized to zero, and basic addition operations are performed.

However, the ROCm code looks unusual. The function __ushort_as_half() merely reinterprets a ushort bit pattern as a half within the type system; it does not execute an actual conversion instruction. But because the function requires a ushort argument, C++ implicit conversion rules convert the half argument to ushort first, and that numeric conversion does execute an actual conversion instruction, altering the value.
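
To make the distinction concrete, below is a minimal standalone sketch (not from the PR or the kernel) contrasting a pure bit reinterpretation with the implicit numeric conversion described above. It assumes a ROCm/HIP toolchain with hip_fp16.h, where the implicit half-to-ushort conversion described above is allowed; the intrinsics are the same ones used in the kernel.

#include <cstdio>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

__global__ void demo() {
  __half h = __float2half(1.0f);               // bit pattern 0x3C00, value 1.0

  // Bit reinterpretation only: the value survives the round trip.
  unsigned short bits = __half_as_ushort(h);   // 0x3C00
  __half same = __ushort_as_half(bits);        // still 1.0

  // What the buggy ROCm path effectively does: the half argument is first
  // converted numerically to ushort (1.0 -> 1), then the bit pattern 0x0001
  // is reinterpreted as a half, yielding a tiny subnormal instead of 1.0.
  __half corrupted = __ushort_as_half(h);

  printf("same = %f, corrupted = %g\n",
         __half2float(same), __half2float(corrupted));
}

int main() {
  demo<<<1, 1>>>();
  hipDeviceSynchronize();
  return 0;
}

Compiled with hipcc, the second value should print as a subnormal around 6e-8 rather than 1.0, which is exactly the kind of silent corruption reported in #7374.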

After removing this erroneous conversion, the issue was resolved.

Additionally, the ROCm compiler does not accept code like res2 = {};: it complains about multiple viable candidate functions, making the call ambiguous. Since res2 is only used within this code block, this PR moves the declaration of res2 inside the block and uses value-initialization syntax, half2 res2{};, to ensure the variable is zero-initialized.
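
For illustration, here is a rough, hypothetical sketch of the shape the code takes after the change, based on the description above rather than the exact diff; accumulate_row, partials, and the __hadd2 loop are invented stand-ins for the kernel's real dequantization and accumulation logic.

#include <hip/hip_fp16.h>

// Hypothetical helper mirroring the structure described above.
__device__ void accumulate_row(half* res, int m, const half2* partials, int n) {
  // Declared inside the block and value-initialized, so both components start
  // at zero on CUDA and ROCm alike; no #ifdef needed.
  half2 res2{};
  for (int i = 0; i < n; ++i)
    res2 = __hadd2(res2, partials[i]);  // stand-in for the real __hfma2 math
  // Plain half additions, with no ushort reinterpretation on either platform.
  res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
}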

FIX #7374


github-actions bot commented May 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@houseroad houseroad requested a review from hongxiayang May 5, 2025 15:51
@houseroad houseroad added the rocm Related to AMD ROCm label May 5, 2025
Collaborator

@gshtras gshtras left a comment


Thanks for the fix; it looks like whichever missing feature from the old ROCm API the ifdefs were meant to address no longer exists.
cc @tjtanaa to verify, since this ifdef was originally added in #2180

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 10, 2025
@kliuae
Contributor

kliuae commented Jun 17, 2025

Looks good and thanks for the fix. The conversions were workarounds for older ROCm where implicit casting wasn’t as comprehensive, but they’re no longer needed with the latest versions.

@tjtanaa
Contributor

tjtanaa commented Jun 17, 2025

> Looks good and thanks for the fix. The conversions were workarounds for older ROCm where implicit casting wasn’t as comprehensive, but they’re no longer needed with the latest versions.

Thank you @kliuae (the author of the PR #2180)

cc. @gshtras

@gshtras
Collaborator

gshtras commented Jun 18, 2025

@nlzy could you please check the tests? Merging from main may be enough to fix them.

btbtyler09 added a commit to btbtyler09/vllm-gfx908 that referenced this pull request Jun 20, 2025
- Fix double type conversion bug in q_gemm.cu affecting all GPTQ models with tensor parallelism on ROCm
- Move half2 res2 declaration inside loop with proper zero initialization
- Remove problematic __half_as_ushort/__ushort_as_half conversions
- Fix false Triton flash attention warning for models with sliding window when VLLM_USE_TRITON_FLASH_ATTN=0
- Changes match upstream PR vllm-project#17583

This fixes silent data corruption that was causing GPTQ models to produce gibberish output on ROCm with tensor parallelism.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
@nlzy
Contributor Author

nlzy commented Jun 22, 2025

Thanks for the reviews. I have checked the failing CI tests, and they appear to be unrelated to this PR.
@gshtras

@DarkLight1337
Member

Can you merge from main to fix the CI failures?

Labels
ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
Development

Successfully merging this pull request may close these issues.

[Bug]: Tensor Parallel > 1 causes desc_act=True GPTQ models to give bad output on ROCm
6 participants