Commit 4180a17

Support on-the-fly quantization for rollout

Summary: For now, only quantizing all linear layers via a torchao config is supported; see the vLLM PR for how to generate the quantization config file. Also requires the vLLM changes in vllm-project/vllm#23014.

Test Plan: sh examples/ppo_trainer/run_deepseek7b_llm.sh
Parent: 0807da9 · Commit: 4180a17

File tree (5 files changed, +33 −6 lines):

- examples/ppo_trainer/run_deepseek7b_llm.sh
- verl/trainer/config/rollout/rollout.yaml
- verl/workers/config/rollout.py
- verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
- verl/workers/sharding_manager/fsdp_vllm.py


examples/ppo_trainer/run_deepseek7b_llm.sh

Lines changed: 2 additions & 0 deletions
```diff
@@ -22,6 +22,8 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+    actor_rollout_ref.rollout.quantization=torchao \
+    actor_rollout_ref.rollout.quantization_config_file=torchao_config.json \
     critic.optim.lr=1e-5 \
     critic.model.use_remove_padding=True \
     critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
```
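The run script references torchao_config.json. The actual schema and generation steps are defined in vllm-project/vllm#23014; purely as an illustration, such a file could plausibly be produced with torchao's own config-serialization helpers (the choice of Int8WeightOnlyConfig and the bare config_to_dict output are assumptions here, not the PR's recipe):

```python
# Hypothetical sketch of producing torchao_config.json. The real schema is
# defined by vllm-project/vllm#23014; this only shows torchao's config
# serialization (config_to_dict) as a plausible building block.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

quant_config = Int8WeightOnlyConfig()  # int8 weight-only quant for linear layers

with open("torchao_config.json", "w") as f:
    json.dump(config_to_dict(quant_config), f, indent=2)
```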

verl/trainer/config/rollout/rollout.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -268,3 +268,6 @@ profiler:

   # specific tool config
   tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
+
+quantization: null
+quantization_config_file: null
```
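Both keys default to null, so the feature is opt-in: the quantization path in fsdp_vllm.py below only triggers when a run sets both values, as in the example script above.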

verl/workers/config/rollout.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -139,3 +139,6 @@ class RolloutConfig(BaseConfig):
     layered_summon: bool = False

     layer_name_map: dict = field(default_factory=dict)
+
+    quantization: Optional[str] = None
+    quantization_config_file: Optional[str] = None
```

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -194,6 +194,8 @@ def __init__(self, model_path: str, config: RolloutConfig, tokenizer, model_hf_c
             enable_prefix_caching=True,
             trust_remote_code=trust_remote_code,
             seed=config.get("seed", 0),
+            quantization=config.quantization,
+            hf_overrides={"quantization_config": {"quantization_config_file": config.quantization_config_file}},
             **compilation_config,
             **lora_kwargs,
             **engine_kwargs,
```
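Standalone, the engine construction above corresponds roughly to the following (a minimal sketch assuming a vLLM build that includes vllm-project/vllm#23014; the model name and config file path are placeholders):

```python
# Minimal sketch: a vLLM engine with on-the-fly torchao quantization enabled.
# Requires a vLLM build containing vllm-project/vllm#23014.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/deepseek-llm-7b-chat",  # placeholder model
    quantization="torchao",
    hf_overrides={
        "quantization_config": {"quantization_config_file": "torchao_config.json"}
    },
)
```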

verl/workers/sharding_manager/fsdp_vllm.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -338,16 +338,33 @@ def replace_lora_wrapper(k):

         updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}

+
         from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

         patch_vllm_moe_model_weight_loader(model)
         device = get_device_id()  # used when fsdp2 set cpu_offload_policy
-        loaded_params = model.load_weights(
-            (
-                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
-                for name, param in updated_params.items()
-            )
-        )
+
+        # make all DTensor full tensor before quantization
+        updated_params = {
+            name: param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param
+            for name, param in updated_params.items()
+        }
+
+        quantization = self.rollout_config.quantization
+        quantization_config_file = self.rollout_config.quantization_config_file
+        quantized_updated_params = {}
+        from vllm.model_executor.layers.quantization import get_quantization_config
+        import json
+        if quantization is not None and quantization_config_file is not None:
+            quant_cls = get_quantization_config(quantization)
+            config = quant_cls.from_config_file(quantization_config_file)
+            for name, param in updated_params.items():
+                if name.endswith("proj.weight"):
+                    quantized_updated_params[name] = config.quantize_param(param)
+                else:
+                    quantized_updated_params[name] = param
+
+        loaded_params = model.load_weights(quantized_updated_params.items())

         self.base_sync_done = True
         logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}")
```
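Taken out of the sharding manager, the quantize-then-load step above reduces to the following (a minimal sketch assuming a vLLM build with vllm-project/vllm#23014, which adds from_config_file/quantize_param to the torchao quantization config; parameter names and shapes are illustrative):

```python
# Minimal standalone sketch of the on-the-fly quantization path above,
# covering only the enabled case, as the diff does.
import torch
from vllm.model_executor.layers.quantization import get_quantization_config

quant_cls = get_quantization_config("torchao")
quant_config = quant_cls.from_config_file("torchao_config.json")

# Illustrative full (unsharded) weights, i.e. after the DTensor gather.
weights = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(4096, 4096),
    "model.layers.0.input_layernorm.weight": torch.ones(4096),
}

# Only linear projection weights ("...proj.weight") are quantized; everything
# else (norms, embeddings) passes through unchanged.
quantized = {
    name: quant_config.quantize_param(w) if name.endswith("proj.weight") else w
    for name, w in weights.items()
}
```

This matches the commit summary's stated scope: for now, all linear layers are quantized, selected purely by the "proj.weight" name suffix.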
