15 changes: 14 additions & 1 deletion docs/advance/ppo_lora.rst
@@ -49,11 +49,24 @@ Best Practices and Notes
- LoRA rank recommendation from @thelongestusernameofall:

  - A very small lora_rank can lead to slower convergence or worse training performance. It is recommended to set lora_rank >= 32. Tests have shown that for a 0.5B model, with lora_rank=32, the training convergence speed and final performance are almost identical to non-LoRA training.
  - For a 32B model, with lora_rank=128, the training convergence speed and final performance are also almost identical to non-LoRA training.
- More comprehensive reference results are coming soon.

.. image:: https://github.com/eric-haibin-lin/verl-community/blob/f2b80b8b26829124dd393b7a795a0640eff11644/docs/lora.jpg?raw=true

SVD-LoRA adapters
-----------------

`ESSA: Evolutionary Strategies for Scalable Alignment <https://arxiv.org/abs/2507.04453>`_ introduced SVD-LoRA-GRPO, which
factorizes LoRA adapters into orthogonal matrices plus singular values. ``verl`` includes a matching experimental
conversion:

- Enable it with ``actor_rollout_ref.model.use_svd_lora=True`` (see :mod:`verl.workers.config.model`). It currently
  activates only when the actor runs with ``fsdp2``; a command-line sketch follows this list.
- Initialization calls :func:`verl.utils.experimental.svd_lora.apply_svd_lora`, swapping LoRA modules for
``SVDLinear`` layers whose weights are rebuilt from ``U``, ``sigma``, and ``V``; only the singular values stay
trainable.
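
A minimal command-line sketch (assuming the standard ``verl.trainer.main_ppo`` entry point and the Hydra-style
dotted overrides used elsewhere in this guide; the rank and alpha values below are illustrative):

.. code-block:: bash

   # Sketch only: pairs the experimental toggle with the fsdp2 actor strategy it requires
   python3 -m verl.trainer.main_ppo \
       actor_rollout_ref.actor.strategy=fsdp2 \
       actor_rollout_ref.model.lora_rank=32 \
       actor_rollout_ref.model.lora_alpha=32 \
       actor_rollout_ref.model.use_svd_lora=True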

3. Reference configuration for RL training with the Qwen2.5-72B model using 8 x 80GB GPUs (increase lora_rank if needed):

.. code-block::
91 changes: 91 additions & 0 deletions verl/utils/experimental/svd_lora.py
@@ -0,0 +1,91 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class SVDLinear(nn.Module):
    """SVD-LoRA-GRPO linear layer from `ESSA: Evolutionary Strategies for Scalable Alignment
    <https://arxiv.org/abs/2507.04453>`_. The weight is reconstructed as ``(U * sigma) @ V.T``;
    only the singular values ``sigma`` are trainable."""
U: torch.Tensor
sigma: torch.Tensor
V: torch.Tensor

def __init__(
self,
U: torch.Tensor,
sigma: torch.Tensor,
V: torch.Tensor,
dtype: torch.dtype = None,
device: torch.device = None,
):
super().__init__()

if device is not None:
U = U.to(device)
sigma = sigma.to(device)
V = V.to(device)

U = U.contiguous()
sigma = sigma.contiguous()
V = V.contiguous()

        # Cast before wrapping so each factor stays a registered nn.Parameter
        # (``.to`` on a Parameter can return a plain tensor when it makes a copy).
        self.U = nn.Parameter(U.to(torch.float32), requires_grad=False)
        self.sigma = nn.Parameter(sigma.to(torch.float32), requires_grad=True)
        self.V = nn.Parameter(V.to(torch.float32), requires_grad=False)

self.out_features = self.U.size(0)
self.in_features = self.V.size(0)
self.bias = None

@staticmethod
def create_from_weight(weight: torch.Tensor) -> "SVDLinear":
        # Reduced SVD: with full_matrices=True the (U * sigma) @ V.T reconstruction breaks for non-square weights.
        U, S, Vh = torch.linalg.svd(weight.to(torch.float32), full_matrices=False)
V = Vh.T
return SVDLinear(U, S, V, dtype=weight.dtype, device=weight.device)

def _get_svd_weight(self) -> torch.Tensor:
W = (self.U * self.sigma) @ self.V.T
return W.contiguous()

@property
def weight(self) -> torch.Tensor:
return self._get_svd_weight()

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input @ self.weight.T


def apply_svd_lora(model: nn.Module) -> nn.Module:
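    """Replace each PEFT ``lora_A``/``lora_B`` ``nn.ModuleDict`` entry with an ``SVDLinear`` rebuilt
    from the SVD of its weight, so that only the singular values remain trainable. The model is
    modified in place and returned."""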
def set_deep_attr(obj, attr_path, value):
parts = attr_path.split(".")
for part in parts[:-1]:
obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
setattr(obj, parts[-1], value)

replacements = {}

for name, mod in model.named_modules():
if ("lora_A" in name or "lora_B" in name) and not name.endswith(".default"):
if isinstance(mod, nn.ModuleDict):
new_dict = nn.ModuleDict()
for key, sub in mod.items():
W = sub.weight.data
new_dict[key] = SVDLinear.create_from_weight(W)
replacements[name] = new_dict

for name, new_mod in replacements.items():
set_deep_attr(model, name, new_mod)

return model
46 changes: 35 additions & 11 deletions verl/utils/fsdp_utils.py
@@ -18,7 +18,7 @@
import math
import os
from abc import ABC
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from contextlib import contextmanager, nullcontext

import torch
@@ -566,7 +566,34 @@ def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinit
return total_norm


def layered_summon_lora_params(fsdp_module) -> OrderedDict:
def to_tensor(param, move_to_cpu: bool):
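    """Return a detached copy of ``param``, gathering the full tensor from sharded params
    (``full_tensor``) when available and optionally moving it to CPU."""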
tensor = param.full_tensor() if hasattr(param, "full_tensor") else param
tensor = tensor.detach()
return tensor.cpu() if move_to_cpu else tensor


def gather_standard_params(state_dict, prefix: str) -> OrderedDict:
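    """Prefix each LoRA parameter name with ``prefix`` and return detached CPU copies."""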
collected = OrderedDict()
for name, param in state_dict.items():
collected[f"{prefix}.{name}"] = to_tensor(param, move_to_cpu=True)
return collected


def gather_svd_params(state_dict, prefix: str) -> OrderedDict:
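    """Group the ``U``/``sigma``/``V`` factors of each ``SVDLinear`` layer and rebuild the dense
    ``.weight`` tensor, so downstream consumers (e.g. the vLLM weight sync) receive standard LoRA weights."""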
grouped = defaultdict(dict)
for name, param in state_dict.items():
layer_prefix, layer_suffix = name.rsplit(".", 1)
grouped[layer_prefix][layer_suffix] = to_tensor(param, move_to_cpu=False)
collected = OrderedDict()
for layer_name, parts in grouped.items():
U = parts.get("U")
sigma = parts.get("sigma")
V = parts.get("V")
collected[f"{prefix}.{layer_name}.weight"] = ((U * sigma) @ V.T).cpu()
return collected


def layered_summon_lora_params(fsdp_module, use_svd_lora: bool) -> OrderedDict:
from peft.utils.save_and_load import get_peft_model_state_dict

def __prefix_submodules(module, prefix):
@@ -596,19 +623,16 @@ def __prefix_submodules(module, prefix):
if fsdp_version(submodule) > 0:
with FSDP.summon_full_params(submodule, writeback=False):
sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())
sub_lora_params = {
f"{prefix}.{name}": param.full_tensor().detach().cpu()
if hasattr(param, "full_tensor")
else param.detach().cpu()
for name, param in sub_lora_params.items()
}
lora_params.update(sub_lora_params)
if use_svd_lora:
lora_params.update(gather_svd_params(sub_lora_params, prefix))
else:
lora_params.update(gather_standard_params(sub_lora_params, prefix))
submodule._is_root = False
get_torch_device().empty_cache()
return lora_params


def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:
def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool, use_svd_lora: bool) -> OrderedDict:
"""
collect lora params or full params if base model is not ready in vllm
work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)
@@ -624,7 +648,7 @@
"To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
"rollout.load_format=safetensors"
)
lora_params = layered_summon_lora_params(module)
lora_params = layered_summon_lora_params(module, use_svd_lora)
else:
with FSDP.summon_full_params(module, writeback=False):
if base_sync_done:
2 changes: 2 additions & 0 deletions verl/workers/config/model.py
@@ -78,6 +78,8 @@ class HFModelConfig(BaseConfig):
lora_rank: int = 0
lora_alpha: int = 16
target_modules: Optional[str] = "all-linear"
    use_svd_lora: bool = False  # Toggle for the SVD-LoRA-GRPO approach described in
    # [ESSA: Evolutionary Strategies for Scalable Alignment](https://arxiv.org/abs/2507.04453)

exclude_modules: Optional[str] = None
use_liger: bool = False
10 changes: 9 additions & 1 deletion verl/workers/fsdp_workers.py
@@ -61,6 +61,7 @@
is_npu_available,
set_expandable_segments,
)
from verl.utils.experimental.svd_lora import apply_svd_lora
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
@@ -420,6 +421,9 @@ def _build_model_optimizer(
}
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))

if self.config.model.get("use_svd_lora", False) and self.config.actor.strategy == "fsdp2":
actor_module = apply_svd_lora(actor_module)

self.use_orig_params = fsdp_config.get("use_orig_params", False)
if self.config.actor.get("freeze_vision_tower", False):
vision_tower = get_vl_model_vision_tower(actor_module)
@@ -647,6 +651,7 @@ async def rollout_mode(self):
module=self.actor_module_fsdp,
layered_summon=self.config.rollout.get("layered_summon", False),
base_sync_done=self.base_sync_done,
use_svd_lora=self.config.model.get("use_svd_lora", False),
)
if not self.base_sync_done:
params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
@@ -666,6 +671,7 @@
module=self.actor_module_fsdp,
layered_summon=self.layered_summon,
base_sync_done=False,
use_svd_lora=self.config.model.get("use_svd_lora", False),
)
base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
base_model_params = convert_weight_keys(
@@ -1040,7 +1046,9 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to
try:
if fsdp_version(self.actor_module_fsdp) > 0:
self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())
lora_params = layered_summon_lora_params(self.actor_module_fsdp)
lora_params = layered_summon_lora_params(
self.actor_module_fsdp, self.config.model.get("use_svd_lora", False)
)
if dist.get_rank() == 0:
save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors"))
with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: