2 changes: 2 additions & 0 deletions examples/ppo_trainer/run_deepseek7b_llm.sh
@@ -22,6 +22,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.rollout.quantization=torchao \
actor_rollout_ref.rollout.quantization_config_file=torchao_config.json \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
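The run script above points actor_rollout_ref.rollout.quantization_config_file at a torchao_config.json. A hypothetical sketch of how such a file might be produced follows; the exact schema verl/vLLM expects is not spelled out in this diff, and Int8WeightOnlyConfig plus config_to_dict are assumptions based on recent torchao releases.

# Hypothetical sketch: serialize a torchao quantization config to the
# torchao_config.json referenced by the run script. Assumes torchao's
# config_to_dict serializer; the schema vLLM expects may differ.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

quant_config = Int8WeightOnlyConfig()  # int8 weight-only quantization

with open("torchao_config.json", "w") as f:
    json.dump(config_to_dict(quant_config), f, indent=2)
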
3 changes: 3 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -289,3 +289,6 @@ profiler:

# specific tool config
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}

quantization: null
quantization_config_file: null
3 changes: 3 additions & 0 deletions verl/workers/config/rollout.py
@@ -176,6 +176,9 @@ class RolloutConfig(BaseConfig):

skip_tokenizer_init: bool = False

quantization: Optional[str] = None
quantization_config_file: Optional[str] = None

def __post_init__(self):
"""Validate the rollout config"""
if self.expert_parallel_size > 1:
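Since __post_init__ already validates the rollout config, a natural consistency check on the two new fields would reject a config file supplied without a backend. A hypothetical sketch, not part of this PR:

# Hypothetical helper (not in this PR) mirroring a check that
# RolloutConfig.__post_init__ could perform on the new fields.
from typing import Optional

def check_quantization_fields(quantization: Optional[str], quantization_config_file: Optional[str]) -> None:
    # A config file without a quantization backend would be silently ignored.
    if quantization_config_file is not None and quantization is None:
        raise ValueError("quantization_config_file is set but quantization is None; set e.g. quantization='torchao'")
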
15 changes: 11 additions & 4 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -52,8 +52,8 @@
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.v1.worker.worker_base import WorkerWrapperBase
from vllm.model_executor.model_loader import get_model_loader

from verl import DataProto
from verl.third_party.vllm import VLLM_SLEEP_LEVEL
@@ -185,6 +185,7 @@ def __init__(
else:
logger.warning(f"cudagraph_capture_sizes must be a list, but got {cudagraph_capture_sizes}")

hf_overrides = (
    {"quantization_config_file": config.quantization_config_file}
    if config.quantization is not None and config.quantization_config_file is not None
    else None
)
self.inference_engine = LLM(
model=model_path,
enable_sleep_mode=config.free_cache_engine,
@@ -204,6 +204,8 @@
enable_prefix_caching=config.enable_prefix_caching,
trust_remote_code=trust_remote_code,
seed=config.get("seed", 0),
quantization=config.quantization,
hf_overrides=hf_overrides,
**compilation_config,
**self.lora_kwargs,
**engine_kwargs,
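
Stripped of verl's plumbing, the wiring above reduces to the following minimal sketch. quantization and hf_overrides are existing vLLM engine arguments; the model name and config path are illustrative, taken from the run script:

# Minimal standalone sketch of the quantization wiring above.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/deepseek-llm-7b-chat",
    quantization="torchao",
    hf_overrides={"quantization_config_file": "torchao_config.json"},
)
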
@@ -450,7 +453,11 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None

model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
patch_vllm_moe_model_weight_loader(model)
model.load_weights(weights)
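# The loader-based path below replaces the direct model.load_weights(weights)
# call above: routing the sync through vLLM's model loader re-applies
# load-time transforms, which appears necessary so that quantization
# (e.g. torchao) still takes effect on the freshly synced weights.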
load_config = self.inference_engine.llm_engine.vllm_config.load_config
load_config.load_format = "auto"
model_config = self.inference_engine.llm_engine.model_config
model_loader = get_model_loader(load_config)
model_loader.load_weights(model, model_config=model_config)


# https://github.com/vllm-project/vllm/issues/13175
Expand All @@ -460,7 +467,7 @@ def _monkey_patch_compute_logits(model, vocab_size: int):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata,
) -> torch.Tensor:
logits = original_compute_logits(hidden_states, sampling_metadata)
logits[..., vocab_size:] = float("-inf")