Skip to content

[TPU][Bugfix] fix the MoE OOM issue #20339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2025

Conversation

yaochengji
Copy link
Collaborator

@yaochengji yaochengji commented Jul 1, 2025

Purpose

The XLA backend for TPUs handles its own functionalization, so we don't need to wrap it as a custom operation to benefit from torch.compile's auto-functionalization. Additionally, using a custom operation would cause HBM OOM errors on TPU.

Test Plan

vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --seed 42 --disable-log-requests --gpu-memory-utilization 0.95  --max-num-batched-tokens 4096 --max-num-seqs 256 --tensor-parallel-size 8 --max-model-len 2048 --no-enable-prefix-caching 

Test Result

Passed. (There's HBM OOM without the fix)

Copy link

github-actions bot commented Jul 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @yaochengji, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a targeted fix for an Out-Of-Memory (OOM) issue encountered when running Mixture-of-Experts (MoE) models on TPUs. By optimizing the MoE forward pass to bypass an unnecessary custom operation on TPU platforms, it leverages the XLA backend's native functionalization capabilities, leading to significant memory savings and enabling successful execution of these models.

Highlights

  • TPU-Specific MoE Optimization: I've introduced a conditional logic within the Mixture-of-Experts (MoE) layer's forward pass to specifically handle execution on TPUs. When running on a TPU, the code now directly calls self.forward_impl instead of wrapping the operation as a custom torch.ops.vllm.moe_forward.
  • Out-Of-Memory (OOM) Fix: This change directly addresses and resolves an Out-Of-Memory (OOM) issue that was occurring on TPUs when processing MoE models. The custom operation wrapping was found to be redundant and memory-intensive on TPUs, as the XLA backend handles functionalization natively.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@yaochengji yaochengji requested review from ProExpertProg and mgoin July 1, 2025 21:53
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses an Out-of-Memory (OOM) issue with Mixture-of-Experts (MoE) layers on TPUs. The fix involves bypassing a custom torch.ops operation on the TPU platform, as the XLA backend handles functionalization automatically. The change is implemented with a clean conditional check for the TPU platform in the FusedMoE.forward method, directly calling the underlying implementation and avoiding the problematic custom op wrapper. The change is straightforward, well-commented, and appears to correctly resolve the issue as described.

@mgoin mgoin requested a review from zou3519 July 1, 2025 22:18
Comment on lines 1567 to 1570
# Note: The XLA backend for TPUs handles its own functionalization, so
# we don't need to wrap it as a custom operation to benefit from
# torch.compile's auto-functionalization. Additionally, using a
# custom operation would cause HBM OOM errors on TPU.
Copy link
Collaborator

@zou3519 zou3519 Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned this offline, but I think what is happening is:

  • vLLM on TPUs uses torch.compile(backend="openxla")
  • torch.compile's frontend (TorchDynamo) isn't responsible for adding auto_functionalization calls. The backend does. On GPUs, backend="inductor" adds the auto_functionalization calls (backend=inductor runs AOTAutograd+Inductor, AOTAutograd is the thing that performs functionalization).
  • The implication is that backend="openxla" is adding the auto_functionalized calls and the not handling them correctly. My guess is that backend="openxla" uses AOTAutograd in part of its implementation.

Does this sound reasonable? If so, then I think the action items are:

  • I'm happy to ship the current patch if it works for you.
  • You should change the comment to indicate this is a workaround for backend="openxla" and that there's a deeper issue going on. This is not the first time TPUs will run into this issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zou3519 for the detailed instruction! I will firstly check the openxla backend to see whether I can fix this in the backend quickly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After offline discussion with @zou3519 and @bnellnm and more experiments, we've determined that the issue persists even after removing auto_functionalization. The OOM problem consistently appears whenever we use custom operations on the TPU backend.

I'll need some time to investigate the root cause, but in the interim, I believe this PR can serve as a workaround to unblock MoE models for TPU.

@yaochengji yaochengji marked this pull request as draft July 2, 2025 16:59
@yaochengji yaochengji force-pushed the chengji/fix-moe-oom branch from bc687ac to 8477625 Compare July 2, 2025 21:40
@yaochengji yaochengji marked this pull request as ready for review July 2, 2025 21:58
@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 5, 2025
@yaochengji yaochengji enabled auto-merge (squash) July 6, 2025 02:21
@vllm-bot vllm-bot merged commit 4548c03 into vllm-project:main Jul 6, 2025
79 of 83 checks passed
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
Signed-off-by: Chengji Yao <[email protected]>
Signed-off-by: Patrick von Platen <[email protected]>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants