
Conversation

jerryzh168

Summary:
Only quantizing all linear layers with a torchao config is supported for now; see the vLLM PR for how to generate the quantization config file.
Also requires vLLM changes: vllm-project/vllm#23014
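
As a rough illustration of what "quantizing all linear layers with a torchao config" means (a minimal sketch, not this PR's code path, assuming a recent torchao release that exposes the config-class API such as Int8WeightOnlyConfig; the PR itself drives this through a serialized config file consumed by vLLM):

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# quantize_() walks the module and swaps the weight of every nn.Linear
# (the default filter) according to the given torchao config.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 4))
quantize_(model, Int8WeightOnlyConfig())

print(model[0].weight)  # now backed by torchao's quantized tensor subclass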

Test Plan:
sh examples/ppo_trainer/run_deepseek7b_llm.sh

Reviewers:

Subscribers:

Tasks:

Tags:

What does this PR do?

Add a concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with a comma, like [megatron, fsdp, doc]
    • {type} is one of feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and show results such as training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes, if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this
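
A hedged sketch of the two new rollout options this PR introduces (field names taken from the sharding-manager diff discussed below: quantization and quantization_config_file; their exact location in verl's config tree and the file name used here are assumptions):

from omegaconf import OmegaConf

# Hypothetical rollout config fragment; only the two keys below are new.
rollout_config = OmegaConf.create(
    {
        "quantization": "torchao",  # vLLM quantization backend used during rollout
        # JSON file describing the torchao quantization config; see the companion
        # vLLM PR (vllm-project/vllm#23014) for how to generate it.
        "quantization_config_file": "torchao_config.json",
    }
)

# Leaving either key as None keeps today's behaviour: updated FSDP weights are
# loaded into the vLLM rollout model unquantized.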

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review; otherwise, the reviewer might deprioritize this PR.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for on-the-fly quantization for vLLM rollouts using torchao. The changes involve adding configuration options for quantization and implementing the quantization logic within the FSDP sharding manager.

My review identified a critical issue in the implementation where model weights would fail to load if quantization is disabled. I have provided a code suggestion to fix this. Additionally, I've pointed out that the current method for selecting layers to quantize is too specific and may miss some linear layers, which could lead to unexpected behavior.

Comment on lines 353 to 369
quantization = self.rollout_config.quantization
quantization_config_file = self.rollout_config.quantization_config_file
quantized_updated_params = {}
from vllm.model_executor.layers.quantization import get_quantization_config
import json
if quantization is not None and quantization_config_file is not None:
    quant_cls = get_quantization_config(quantization)
    config = quant_cls.from_config_file(quantization_config_file)
    for name, param in updated_params.items():
        if name.endswith("proj.weight"):
            quantized_updated_params[name] = config.quantize_param(param)
        else:
            quantized_updated_params[name] = param

loaded_params = model.load_weights(quantized_updated_params.items())
Contributor


critical

This block has a couple of issues:

  1. Critical Bug: When quantization is disabled (quantization is None), quantized_updated_params remains empty. Consequently, model.load_weights() is called with no parameters, meaning the model weights are not loaded at all. This needs to be fixed to load updated_params when quantization is off.

  2. Incomplete Quantization: The condition name.endswith("proj.weight") for selecting parameters to quantize is too specific. It will miss other linear layers, such as lm_head.weight. This will result in a partially quantized model, which can lead to correctness issues or suboptimal performance. This should be made more general, perhaps by using a configurable list of patterns.

I've provided a suggestion that fixes the critical bug and also refactors the logic for clarity. The incomplete quantization issue will still need to be addressed (a rough pattern-based sketch follows the suggested change below).

Suggested change
Before:

quantization = self.rollout_config.quantization
quantization_config_file = self.rollout_config.quantization_config_file
quantized_updated_params = {}
from vllm.model_executor.layers.quantization import get_quantization_config
import json
if quantization is not None and quantization_config_file is not None:
    quant_cls = get_quantization_config(quantization)
    config = quant_cls.from_config_file(quantization_config_file)
    for name, param in updated_params.items():
        if name.endswith("proj.weight"):
            quantized_updated_params[name] = config.quantize_param(param)
        else:
            quantized_updated_params[name] = param
loaded_params = model.load_weights(quantized_updated_params.items())

After:

params_to_load = updated_params
quantization = self.rollout_config.quantization
quantization_config_file = self.rollout_config.quantization_config_file
if quantization is not None and quantization_config_file is not None:
    from vllm.model_executor.layers.quantization import get_quantization_config
    quant_cls = get_quantization_config(quantization)
    config = quant_cls.from_config_file(quantization_config_file)
    quantized_updated_params = {}
    for name, param in updated_params.items():
        if name.endswith("proj.weight"):
            quantized_updated_params[name] = config.quantize_param(param)
        else:
            quantized_updated_params[name] = param
    params_to_load = quantized_updated_params
loaded_params = model.load_weights(params_to_load.items())
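
A minimal sketch of the configurable pattern list suggested above (illustrative only, not part of this PR or the bot's suggestion; the default suffixes are hypothetical):

# Select parameters to quantize by configurable name suffixes instead of the
# hard-coded "proj.weight" check, so layers like lm_head.weight are not skipped.
QUANTIZE_SUFFIXES = ("proj.weight", "lm_head.weight")  # hypothetical defaults

def split_params_for_quantization(updated_params, suffixes=QUANTIZE_SUFFIXES):
    """Return (to_quantize, to_keep) dicts split by parameter-name suffix."""
    to_quantize, to_keep = {}, {}
    for name, param in updated_params.items():
        # str.endswith accepts a tuple, so one check covers all patterns.
        (to_quantize if name.endswith(suffixes) else to_keep)[name] = param
    return to_quantize, to_keep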

@jerryzh168 jerryzh168 marked this pull request as draft August 16, 2025 00:02
@jerryzh168
Author

Please let me know if the API makes sense; I can clean up both PRs after confirmation.

@vadimkantorov

vadimkantorov commented Aug 18, 2025

I also wonder how torchao fp8 quantization compares to vLLM's own implementation for quantization="fp8" (and related), and I'm curious why vLLM doesn't use torchao for this (IIUC there are several "quantization backends" in vLLM?), why they have to use CUDA kernels https://github.com/vllm-project/vllm/tree/main/csrc/quantization/fp8 (why Triton is not sufficient), and why these kernels are not upstreamed to PyTorch core :) Hoping that at least at the kernel level, fragmentation can be reduced by upstreaming more of these https://github.com/vllm-project/vllm/tree/main/csrc/quantization :)

Also, how does the approach in this PR compare to the FlashRL approach (which also patches vLLM)?

@jerryzh168
Author

> I also wonder how torchao fp8 quantization compares to vLLM's own implementation for quantization="fp8" (and related), and I'm curious why vLLM doesn't use torchao for this (IIUC there are several "quantization backends" in vLLM?), why they have to use CUDA kernels vllm-project/vllm@main/csrc/quantization/fp8 (why Triton is not sufficient), and why these kernels are not upstreamed to PyTorch core :) Hoping that at least at the kernel level, fragmentation can be reduced by upstreaming more of these vllm-project/vllm@main/csrc/quantization :)

We haven't officially compared with them yet, but we are integrating with fbgemm kernels, which should be SOTA.
On why vLLM doesn't use torchao: torchao was integrated into vLLM only recently (a few months ago): https://docs.vllm.ai/en/latest/features/quantization/torchao.html, and we are actively working on improving torchao for vLLM users. I don't have context on vllm-project/vllm@main/csrc/quantization/fp8 though.

> Also, how does the approach in this PR compare to the FlashRL approach (which also patches vLLM)?

We haven't compared with that yet either, but I think we can discuss which API makes the most sense.
