[vllm] feat: Support on the fly quant for rollout with torchao #3084
Conversation
Code Review
This pull request introduces support for on-the-fly quantization for vLLM rollouts using torchao. The changes involve adding configuration options for quantization and implementing the quantization logic within the FSDP sharding manager.
My review identified a critical issue in the implementation where model weights would fail to load if quantization is disabled. I have provided a code suggestion to fix this. Additionally, I've pointed out that the current method for selecting layers to quantize is too specific and may miss some linear layers, which could lead to unexpected behavior.
```python
quantization = self.rollout_config.quantization
quantization_config_file = self.rollout_config.quantization_config_file
quantized_updated_params = {}
from vllm.model_executor.layers.quantization import get_quantization_config
import json

if quantization is not None and quantization_config_file is not None:
    quant_cls = get_quantization_config(quantization)
    config = quant_cls.from_config_file(quantization_config_file)
    for name, param in updated_params.items():
        if name.endswith("proj.weight"):
            quantized_updated_params[name] = config.quantize_param(param)
        else:
            quantized_updated_params[name] = param

loaded_params = model.load_weights(quantized_updated_params.items())
```
This block has a couple of issues:
- **Critical Bug:** When quantization is disabled (`quantization` is `None`), `quantized_updated_params` remains empty. Consequently, `model.load_weights()` is called with no parameters, meaning the model weights are not loaded at all. This needs to be fixed to load `updated_params` when quantization is off.
- **Incomplete Quantization:** The condition `name.endswith("proj.weight")` for selecting parameters to quantize is too specific. It will miss other linear layers, such as `lm_head.weight`. This will result in a partially quantized model, which can lead to correctness issues or suboptimal performance. This should be made more general, perhaps by using a configurable list of patterns (see the sketch after the suggestion below).
I've provided a suggestion that fixes the critical bug and also refactors the logic for clarity. The incomplete quantization issue will still need to be addressed.
Suggested change (replacing the block above):

```python
params_to_load = updated_params
quantization = self.rollout_config.quantization
quantization_config_file = self.rollout_config.quantization_config_file
if quantization is not None and quantization_config_file is not None:
    from vllm.model_executor.layers.quantization import get_quantization_config

    quant_cls = get_quantization_config(quantization)
    config = quant_cls.from_config_file(quantization_config_file)
    quantized_updated_params = {}
    for name, param in updated_params.items():
        if name.endswith("proj.weight"):
            quantized_updated_params[name] = config.quantize_param(param)
        else:
            quantized_updated_params[name] = param
    params_to_load = quantized_updated_params
loaded_params = model.load_weights(params_to_load.items())
```
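For the incomplete-quantization issue, one way to generalize the layer selection is to drive it from a configurable list of name patterns rather than the hard-coded `proj.weight` suffix. A minimal sketch of that idea, dropping into the suggested block above; the `quantize_patterns` config field is hypothetical and not part of this PR:

```python
import re

# Hypothetical config field listing regex patterns for parameter names to
# quantize; not part of this PR, shown only to illustrate the suggestion.
quantize_patterns = getattr(self.rollout_config, "quantize_patterns", None) or [
    r"proj\.weight$",
    r"lm_head\.weight$",
]
compiled = [re.compile(pattern) for pattern in quantize_patterns]

quantized_updated_params = {}
for name, param in updated_params.items():
    # Quantize any parameter whose name matches one of the patterns.
    if any(p.search(name) for p in compiled):
        quantized_updated_params[name] = config.quantize_param(param)
    else:
        quantized_updated_params[name] = param
params_to_load = quantized_updated_params
```

This keeps the default behavior close to the current code while letting users opt additional layers (e.g. `lm_head.weight`) in or out per model.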
(force-pushed from `4180a17` to `a841cba`)
Please let me know if the API makes sense; I can clean up both PRs after confirmation.
I also wonder how torchao fp8 quantization compares to vLLM's own impl. And also, how does the approach in this PR compare to the FlashRL approach (which also patches vLLM)?
We haven't officially compared with them yet, but we are integrating with fbgemm kernels, which should be SOTA.

Also haven't compared yet, but we can discuss what API would make the most sense, I think.
(force-pushed from `a841cba` to `d06a31d`)
(force-pushed from `d06a31d` to `3bf8cd5`)
Summary:
Only supporting quantization of all linear layers with a torchao config for now; see the vLLM PR for how to generate the quantization file.
Also requires vLLM changes: vllm-project/vllm#23014

Test Plan:
sh examples/ppo_trainer/run_deepseek7b_llm.sh
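For reference on generating the quantization file: torchao configs can be serialized to JSON, which is one plausible way such a file could be produced. A rough sketch, assuming torchao's config serialization helpers (`config_to_dict` from `torchao.core.config`) and an fp8 config; the exact format expected by the vLLM PR may differ:

```python
# Sketch: serializing a torchao quantization config to a JSON file.
# Assumes torchao's config serialization API; the exact file format
# expected by vllm-project/vllm#23014 may differ.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

config = Float8DynamicActivationFloat8WeightConfig()
with open("torchao_quant_config.json", "w") as f:
    json.dump(config_to_dict(config), f)
```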
What does this PR do?
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`; separate multiple modules with `,`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If this PR breaks any API, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
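A plausible way to enable this from an existing example script, based on the fields read in this PR (`quantization` and `quantization_config_file`); the `actor_rollout_ref.rollout.*` override path is an assumption about verl's standard rollout config layout:

```sh
# Hypothetical invocation; the override path is assumed from verl's
# standard config layout, the field names come from this PR.
sh examples/ppo_trainer/run_deepseek7b_llm.sh \
    actor_rollout_ref.rollout.quantization=torchao \
    actor_rollout_ref.rollout.quantization_config_file=/path/to/torchao_quant_config.json
```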
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI via the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)