11 changes: 9 additions & 2 deletions docs/examples/config.rst
@@ -614,21 +614,28 @@ Optim
.. code:: yaml

optim:
optimizer: AdamW
optimizer_impl: torch.optim
lr: 1e-5
weight_decay: 0.01
warmup_steps_ratio: 0.1
lr_warmup_steps_ratio: 0.1
clip_grad: 1.0
lr_scheduler: cosine
override_optimizer_config: null

- ``optim.optimizer``: Optimizer class name, written exactly as it appears in the module (e.g., ``"AdamW"``, ``"AdamW8bit"``, ``"_AdamW"``).
- ``optim.optimizer_impl``: Module path to import the optimizer from (e.g., ``"torch.optim"``, ``"torchao.optim"``, ``"bitsandbytes.optim"``).
- ``optim.lr``: Learning rate for the optimizer.
- ``optim.weight_decay``: Weight decay for the optimizer.
- ``optim.warmup_steps_ratio``: Ratio of warmup steps to total training steps.
- ``optim.lr_warmup_steps_ratio``: Ratio of warmup steps to total training steps.
- ``optim.clip_grad``: Gradient clipping value.
- ``optim.lr_scheduler``: Learning rate scheduler type. Options:

- ``cosine``: Cosine learning rate scheduler with warmup (default).
- ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases.

- ``optim.override_optimizer_config``: Dictionary of additional optimizer-specific keyword arguments. For example, to use ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding: ``{"bf16_stochastic_round": true}``. A short sketch of how these fields fit together follows below.
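
For illustration, a minimal sketch of how these fields could be resolved into an optimizer instance, assuming the usual dynamic-import pattern; verl's actual ``build_optimizer`` helper may differ in signature and defaults.

.. code:: python

   import importlib

   from torch import nn

   def resolve_optimizer(params, optimizer_impl, optimizer, lr, weight_decay=0.0,
                         override_optimizer_config=None):
       """Import the module named by optimizer_impl, look up the optimizer class,
       and construct it with the base kwargs plus any optimizer-specific overrides."""
       optimizer_cls = getattr(importlib.import_module(optimizer_impl), optimizer)
       kwargs = {"lr": lr, "weight_decay": weight_decay}
       kwargs.update(override_optimizer_config or {})
       return optimizer_cls(params, **kwargs)

   toy_model = nn.Linear(4, 4)  # stand-in module for the example
   opt = resolve_optimizer(toy_model.parameters(), "torch.optim", "AdamW",
                           lr=1e-5, weight_decay=0.01)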

Model
~~~~~~~~~~~~

9 changes: 2 additions & 7 deletions recipe/prime/prime_fsdp_workers.py
@@ -40,6 +40,7 @@
)
from verl.utils.import_utils import import_external_libs
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.config.optimizer import build_optimizer
from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -87,7 +88,6 @@ def __init__(self, config):

def _build_reward_ref_model_optimizer(self, config):
# the following line is necessary
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

@@ -219,12 +219,7 @@ def _build_reward_ref_model_optimizer(self, config):
cpu_offload=None,
)

reward_optimizer = optim.AdamW(
reward_module.parameters(),
lr=config.model.optim.lr,
betas=config.model.optim.get("betas", (0.9, 0.999)),
weight_decay=config.model.optim.get("weight_decay", 1e-2),
)
reward_optimizer = build_optimizer(reward_module.parameters(), config.model.optim)

total_steps = config.model.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1))
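
For context, a hedged usage sketch of the new helper (not part of the diff): build_optimizer takes the module's parameters plus the optim config node, as the call above shows. The config keys mirror the fsdp.yaml defaults added in this PR; the exact set of keys build_optimizer accepts is an assumption.

# Sketch only; "reward_module" here is a toy stand-in for the real reward model.
from omegaconf import OmegaConf
from torch import nn

from verl.workers.config.optimizer import build_optimizer

optim_cfg = OmegaConf.create(
    {
        "optimizer": "AdamW",
        "optimizer_impl": "torch.optim",
        "lr": 1e-5,
        "weight_decay": 0.01,
        "override_optimizer_config": None,
    }
)
reward_module = nn.Linear(8, 8)
reward_optimizer = build_optimizer(reward_module.parameters(), optim_cfg)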
11 changes: 8 additions & 3 deletions tests/workers/config/test_actor_config_on_cpu.py
@@ -16,7 +16,12 @@
import unittest

from verl.utils.config import omega_conf_to_dataclass
from verl.workers.config import ActorConfig, FSDPActorConfig, McoreActorConfig, OptimizerConfig
from verl.workers.config import (
ActorConfig,
FSDPActorConfig,
McoreActorConfig,
OptimizerConfig,
)


class TestActorConfig(unittest.TestCase):
@@ -31,7 +36,7 @@ def test_config_inheritance(self):
"ppo_micro_batch_size_per_gpu": 256,
"clip_ratio": 0.2,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"_target_": "verl.workers.config.McoreOptimizerConfig",
"lr": 0.1,
},
}
@@ -42,7 +47,7 @@ def test_config_inheritance(self):
"ppo_micro_batch_size_per_gpu": 256,
"clip_ratio": 0.2,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"_target_": "verl.workers.config.FSDPOptimizerConfig",
"lr": 0.1,
},
}
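
A hedged sketch of what these fixtures exercise: converting a config node whose _target_ names a concrete optimizer dataclass. The exact omega_conf_to_dataclass call form is assumed from its import at the top of this test file.

# Sketch only (assumption): omega_conf_to_dataclass resolves _target_ and returns
# an instance of that class, so the optim node below becomes an FSDPOptimizerConfig
# rather than the OptimizerConfig base class.
from verl.utils.config import omega_conf_to_dataclass

optim_node = {
    "_target_": "verl.workers.config.FSDPOptimizerConfig",
    "lr": 0.1,
}
optim_cfg = omega_conf_to_dataclass(optim_node)
assert type(optim_cfg).__name__ == "FSDPOptimizerConfig"
assert optim_cfg.lr == 0.1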
25 changes: 12 additions & 13 deletions tests/workers/config/test_critic_config_on_cpu.py
@@ -23,7 +23,9 @@
from verl.workers.config import (
CriticConfig,
FSDPCriticConfig,
FSDPOptimizerConfig,
McoreCriticConfig,
McoreOptimizerConfig,
OptimizerConfig,
)

@@ -103,16 +105,15 @@ def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):

def test_config_inheritance_hierarchy(self):
"""Test that the inheritance hierarchy is correct."""
optim = OptimizerConfig(lr=0.1)
megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))
assert isinstance(megatron_config, CriticConfig)
assert isinstance(megatron_config, McoreCriticConfig)

fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))
assert isinstance(fsdp_config, CriticConfig)
assert isinstance(fsdp_config, FSDPCriticConfig)

critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=OptimizerConfig(lr=0.1))
assert isinstance(critic_config, CriticConfig)
assert not isinstance(critic_config, McoreCriticConfig)
assert not isinstance(critic_config, FSDPCriticConfig)
@@ -136,22 +137,21 @@ def test_config_dict_interface(self):

def test_frozen_fields_immutability(self):
"""Test that frozen fields raise exceptions when modified after creation."""
optim = OptimizerConfig(lr=0.1)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=OptimizerConfig(lr=0.1))
frozen_fields = ["rollout_n", "strategy", "cliprange_value"]

for field_name in frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(critic_config, field_name, "modified_value")

megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))
megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]

for field_name in megatron_frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(megatron_config, field_name, "modified_value")

fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))
fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]

for field_name in fsdp_frozen_fields:
@@ -171,7 +171,7 @@ def test_batch_size_fields_modifiable(self):
assert critic_config.ppo_micro_batch_size == 4
assert critic_config.ppo_micro_batch_size_per_gpu == 2

fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))

fsdp_config.forward_micro_batch_size = 16
fsdp_config.forward_micro_batch_size_per_gpu = 8
@@ -277,12 +277,11 @@ def test_micro_batch_size_divisibility_validation(self):

def test_fsdp_sequence_parallelism_validation(self):
"""Test FSDP sequence parallelism validation in FSDPCriticConfig.__post_init__."""
optim = OptimizerConfig(lr=0.1)
valid_config = FSDPCriticConfig(
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=2,
model={"use_remove_padding": True},
optim=optim,
optim=FSDPOptimizerConfig(lr=0.1),
)
assert valid_config.ulysses_sequence_parallel_size == 2

@@ -293,13 +292,13 @@ def test_fsdp_sequence_parallelism_validation(self):
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=2,
model={"use_remove_padding": False},
optim=optim,
optim=FSDPOptimizerConfig(lr=0.1),
)

valid_config_no_sp = FSDPCriticConfig(
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=1,
model={"use_remove_padding": False},
optim=optim,
optim=FSDPOptimizerConfig(lr=0.1),
)
assert valid_config_no_sp.ulysses_sequence_parallel_size == 1
4 changes: 2 additions & 2 deletions tests/workers/critic/test_special_dp_critic.py
@@ -22,7 +22,7 @@
from transformers import AutoConfig

from verl import DataProto
from verl.workers.config import FSDPCriticConfig, OptimizerConfig
from verl.workers.config import FSDPCriticConfig, FSDPOptimizerConfig
from verl.workers.config.critic import FSDPCriticModelCfg
from verl.workers.config.engine import FSDPEngineConfig
from verl.workers.fsdp_workers import CriticWorker
@@ -72,7 +72,7 @@ def setUp(self):
use_dynamic_bsz=False,
ulysses_sequence_parallel_size=1,
rollout_n=1,
optim=OptimizerConfig(lr=1e-6),
optim=FSDPOptimizerConfig(lr=1e-6),
model=FSDPCriticModelCfg(
path="Qwen/Qwen2.5-0.5B-Instruct",
tokenizer_path="Qwen/Qwen2.5-0.5B-Instruct",
6 changes: 6 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -7,6 +7,8 @@ actor_rollout_ref:
actor:
optim:
_target_: verl.workers.config.FSDPOptimizerConfig
optimizer: AdamW
optimizer_impl: torch.optim
lr: 1.0e-06
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
@@ -19,6 +21,7 @@
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
override_optimizer_config: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
@@ -302,6 +305,8 @@ data:
critic:
optim:
_target_: verl.workers.config.FSDPOptimizerConfig
optimizer: AdamW
optimizer_impl: torch.optim
lr: 1.0e-05
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
@@ -314,6 +319,7 @@
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
override_optimizer_config: null
model:
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
15 changes: 15 additions & 0 deletions verl/trainer/config/optim/fsdp.yaml
@@ -1,6 +1,13 @@
# Target class for this configuration
_target_: verl.workers.config.FSDPOptimizerConfig

# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam")
optimizer: AdamW

# Module path to import optimizer
# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim"
optimizer_impl: torch.optim

# Learning rate
lr: 1e-3

@@ -31,3 +38,11 @@ num_cycles: 0.5
# LR warmup style: "constant" or "cosine"
warmup_style: constant

# Additional optimizer-specific keyword arguments
# Example for torchao with bf16 stochastic rounding:
# optimizer_impl: torchao.optim
# optimizer: _AdamW
# override_optimizer_config:
# bf16_stochastic_round: true
override_optimizer_config: null
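
To make the torchao example above concrete, a hedged sketch of building the same configuration programmatically; the field names follow the generated ppo_trainer YAML in this PR, while FSDPOptimizerConfig's keyword arguments and build_optimizer's acceptance of the dataclass are assumptions.

# Sketch only: the override dict is expected to be forwarded to the optimizer
# class as extra keyword arguments; torchao must be installed for this to run.
from torch import nn

from verl.workers.config import FSDPOptimizerConfig
from verl.workers.config.optimizer import build_optimizer

optim_cfg = FSDPOptimizerConfig(
    lr=1e-3,
    optimizer="_AdamW",
    optimizer_impl="torchao.optim",
    override_optimizer_config={"bf16_stochastic_round": True},
)
model = nn.Linear(16, 16).bfloat16()  # bf16 weights, matching the stochastic-rounding use case
optimizer = build_optimizer(model.parameters(), optim_cfg)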

6 changes: 5 additions & 1 deletion verl/trainer/config/sft_trainer.yaml
@@ -1,3 +1,7 @@
defaults:
- optim: fsdp
- _self_

data:
train_batch_size: 256
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
@@ -45,7 +49,7 @@ optim:
lr: 1e-5
betas: [0.9, 0.95]
weight_decay: 0.01
warmup_steps_ratio: 0.1
lr_warmup_steps_ratio: 0.1
clip_grad: 1.0
lr_scheduler: cosine
ulysses_sequence_parallel_size: 1
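
As a quick check of the renamed key, warmup length is still derived the same way; only the config name changes (see fsdp_sft_trainer.py below). The total step count here is an assumed value for illustration.

# Sketch: mirrors the computation in fsdp_sft_trainer.py.
total_steps = 2000  # assumed for illustration
lr_warmup_steps_ratio = 0.1
num_warmup_steps = int(total_steps * lr_warmup_steps_ratio)  # -> 200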
14 changes: 4 additions & 10 deletions verl/trainer/fsdp_sft_trainer.py
@@ -34,7 +34,7 @@
from omegaconf import DictConfig, OmegaConf
from peft import LoraConfig, TaskType, get_peft_model
from tensordict import TensorDict
from torch import nn, optim
from torch import nn
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -73,6 +73,7 @@
get_ulysses_sequence_parallel_world_size,
ulysses_pad_and_slice_inputs,
)
from verl.workers.config.optimizer import build_optimizer
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)
@@ -317,14 +318,7 @@ def _build_model_optimizer(self):

log_gpu_memory_usage("After FSDP wrapping", logger=logger)

self.optimizer = optim.AdamW(
self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay,
eps=self.config.optim.get("eps", 1e-08),
fused=True,
)
self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim)

log_gpu_memory_usage("After initialize optimizer", logger=logger)

@@ -337,7 +331,7 @@ def _build_model_optimizer(self):
f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}"
)

num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)
num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio)

if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(