verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (23 additions, 7 deletions)

@@ -50,7 +50,7 @@
 from tensordict import TensorDict
 from torch.distributed.device_mesh import DeviceMesh
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -479,10 +479,12 @@ def __init__(
         device_mesh: DeviceMesh,
     ):
         super().__init__(config, model_config, device_mesh)
-
         self.tokenizer = model_config.tokenizer
         self.inference_engine: WorkerWrapperBase = None
         self.address = self._init_zeromq()
+        self.lora_config = (
+            {"max_loras": 1, "max_lora_rank": model_config.lora_rank} if model_config.lora_rank > 0 else {}
+        )
 
         # https://github.com/vllm-project/vllm/issues/25171
         if config.layered_summon or config.expert_parallel_size > 1:
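
The new `self.lora_config` holds plain kwargs for vLLM's `LoRAConfig` rather than constructing the config eagerly, and the empty dict doubles as a feature flag for the LoRA-disabled case. A minimal sketch of the round trip, with a standalone `lora_rank` value standing in for verl's `model_config.lora_rank`:

```python
from vllm.config import LoRAConfig

lora_rank = 16  # assumed example value; verl reads this from model_config.lora_rank

# Empty dict when LoRA is disabled, so later code can simply gate on truthiness.
lora_kwargs = {"max_loras": 1, "max_lora_rank": lora_rank} if lora_rank > 0 else {}

if lora_kwargs:
    # _init_worker expands these kwargs into vLLM's config dataclass (next hunks).
    lora_config = LoRAConfig(**lora_kwargs)
    print(lora_config.max_loras, lora_config.max_lora_rank)  # 1 16
```

`max_loras=1` caps the engine at a single adapter slot, which fits the update path below, where each weight sync installs one fresh adapter.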
@@ -536,7 +538,6 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
"""Initialize worker engine."""
if not torch.distributed.is_initialized():
initialize_global_process_group_ray()

all_kwargs[0]["rank"] = int(os.environ["RANK"])
device_name = "NPU" if is_npu_available else "GPU"
all_kwargs[0]["local_rank"] = (
@@ -545,6 +546,8 @@
             else int(ray.get_runtime_context().get_accelerator_ids()[device_name][0])
         )
         self.vllm_config = all_kwargs[0]["vllm_config"]
+        if self.lora_config:
+            self.vllm_config.lora_config = LoRAConfig(**self.lora_config)
         self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
         self.inference_engine.init_worker(all_kwargs)

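The ordering here matters: `vllm_config.lora_config` has to be attached before `init_worker` runs, so the worker builds its LoRA machinery during initialization rather than discovering the config afterwards. A condensed sketch of that sequence (names follow the diff; the kwargs plumbing around it is elided):

```python
from vllm.config import LoRAConfig
from vllm.worker.worker_base import WorkerWrapperBase

def build_worker(all_kwargs: list[dict], lora_kwargs: dict) -> WorkerWrapperBase:
    """Sketch: attach the LoRA config, then initialize the worker."""
    vllm_config = all_kwargs[0]["vllm_config"]
    if lora_kwargs:
        # Must precede init_worker so vLLM allocates adapter slots up front.
        vllm_config.lora_config = LoRAConfig(**lora_kwargs)
    engine = WorkerWrapperBase(vllm_config=vllm_config)
    engine.init_worker(all_kwargs)
    return engine
```
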
@@ -582,11 +585,24 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         Args:
             weights: A generator that yields the name of the weight tensor and the tensor itself.
         """
-        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+        peft_config, base_sync_done = kwargs.get("peft_config", None), kwargs.get("base_sync_done", False)
+        if peft_config and base_sync_done:
+            lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
+            lora_request = TensorLoRARequest(
+                lora_name=f"{lora_int_id}",
+                lora_int_id=lora_int_id,
+                lora_path="simon_lora_path",
+                peft_config=asdict(peft_config),
+                lora_tensors=weights,
+            )
+            self.inference_engine.worker.add_lora(lora_request)
+            logger.info(f"vLLM load weights, loaded_params: {len(weights)}")
+        else:
+            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
 
-        model = self.inference_engine.worker.model_runner.model
-        patch_vllm_moe_model_weight_loader(model)
-        model.load_weights(weights)
+            model = self.inference_engine.worker.model_runner.model
+            patch_vllm_moe_model_weight_loader(model)
+            model.load_weights(weights)
 
     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""