Commit f4c7285

Abstract optimizer

1 parent f50e5c2 commit f4c7285

8 files changed: +104 -37 lines changed

docs/examples/config.rst

Lines changed: 7 additions & 0 deletions
@@ -614,12 +614,17 @@ Optim
 .. code:: yaml

     optim:
+      optimizer: AdamW
+      optimizer_impl: torch.optim
       lr: 1e-5
       weight_decay: 0.01
       warmup_steps_ratio: 0.1
       clip_grad: 1.0
       lr_scheduler: cosine
+      override_optimizer_config: null

+- ``optimizer``: Optimizer class name (e.g., ``"AdamW"``, ``"AdamW8bit"``, ``"_AdamW"``). The class name as it appears in the module.
+- ``optimizer_impl``: Module path to import the optimizer from (e.g., ``"torch.optim"``, ``"torchao.optim"``, ``"bitsandbytes.optim"``).
 - ``optim.lr``: Learning rate for the optimizer.
 - ``optim.weight_decay``: Weight decay for the optimizer.
 - ``optim.warmup_steps_ratio``: Ratio of warmup steps to total training steps.
@@ -629,6 +634,8 @@ Optim
 - ``cosine``: Cosine learning rate scheduler with warmup (default).
 - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases.

+- ``override_optimizer_config``: Dictionary of additional optimizer-specific keyword arguments. For example, to use ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding: ``{"bf16_stochastic_round": true}``
+
 Model
 ~~~~~~~~~~~~
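
For example, pointing the same block at bitsandbytes' 8-bit AdamW only requires changing the two new keys (a hypothetical configuration assembled from the options documented above; it assumes the bitsandbytes package is installed):

    optim:
      optimizer: AdamW8bit
      optimizer_impl: bitsandbytes.optim
      lr: 1e-5
      weight_decay: 0.01
      warmup_steps_ratio: 0.1
      clip_grad: 1.0
      lr_scheduler: cosine
      override_optimizer_config: null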

recipe/prime/prime_fsdp_workers.py

Lines changed: 2 additions & 7 deletions
@@ -40,6 +40,7 @@
 )
 from verl.utils.import_utils import import_external_libs
 from verl.utils.profiler import log_gpu_memory_usage
+from verl.workers.config.optimizer import build_optimizer
 from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -87,7 +88,6 @@ def __init__(self, config):

     def _build_reward_ref_model_optimizer(self, config):
         # the following line is necessary
-        from torch import optim
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         from torch.distributed.fsdp import MixedPrecision

@@ -219,12 +219,7 @@ def _build_reward_ref_model_optimizer(self, config):
             cpu_offload=None,
         )

-        reward_optimizer = optim.AdamW(
-            reward_module.parameters(),
-            lr=config.model.optim.lr,
-            betas=config.model.optim.get("betas", (0.9, 0.999)),
-            weight_decay=config.model.optim.get("weight_decay", 1e-2),
-        )
+        reward_optimizer = build_optimizer(reward_module.parameters(), config.model.optim)

         total_steps = config.model.optim.get("total_training_steps", 0)
         num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1))

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,8 @@ actor_rollout_ref:
   actor:
     optim:
       _target_: verl.workers.config.FSDPOptimizerConfig
+      optimizer: AdamW
+      optimizer_impl: torch.optim
       lr: 1.0e-06
       lr_warmup_steps_ratio: 0.0
       total_training_steps: -1
@@ -19,6 +21,7 @@ actor_rollout_ref:
       min_lr_ratio: 0.0
       num_cycles: 0.5
       warmup_style: constant
+      override_optimizer_config: null
     fsdp_config:
       _target_: verl.workers.config.FSDPEngineConfig
       wrap_policy:
@@ -302,6 +305,8 @@ data:
 critic:
   optim:
     _target_: verl.workers.config.FSDPOptimizerConfig
+    optimizer: AdamW
+    optimizer_impl: torch.optim
     lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
     total_training_steps: -1
@@ -314,6 +319,7 @@ critic:
     min_lr_ratio: 0.0
     num_cycles: 0.5
     warmup_style: constant
+    override_optimizer_config: null
   model:
     fsdp_config:
       _target_: verl.workers.config.FSDPEngineConfig

verl/trainer/config/optim/fsdp.yaml

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,13 @@
 # Target class for this configuration
 _target_: verl.workers.config.FSDPOptimizerConfig

+# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam")
+optimizer: AdamW
+
+# Module path to import optimizer
+# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim"
+optimizer_impl: torch.optim
+
 # Learning rate
 lr: 1e-3

@@ -31,3 +38,11 @@ num_cycles: 0.5
 # LR warmup style: "constant" or "cosine"
 warmup_style: constant

+# Additional optimizer-specific keyword arguments
+# Example for torchao with bf16 stochastic rounding:
+#   optimizer_impl: torchao.optim
+#   optimizer: _AdamW
+#   override_optimizer_config:
+#     bf16_stochastic_round: true
+override_optimizer_config: null
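
For reference, enabling the commented torchao example above would make the relevant keys of this file read as follows (a hypothetical variant assembled from the commented lines; it requires the torchao package to be installed):

    _target_: verl.workers.config.FSDPOptimizerConfig
    optimizer: _AdamW
    optimizer_impl: torchao.optim
    lr: 1e-3
    override_optimizer_config:
      bf16_stochastic_round: true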

verl/trainer/fsdp_sft_trainer.py

Lines changed: 3 additions & 7 deletions
@@ -34,7 +34,7 @@
 from omegaconf import DictConfig, OmegaConf
 from peft import LoraConfig, TaskType, get_peft_model
 from tensordict import TensorDict
-from torch import nn, optim
+from torch import nn
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -73,6 +73,7 @@
     get_ulysses_sequence_parallel_world_size,
     ulysses_pad_and_slice_inputs,
 )
+from verl.workers.config.optimizer import build_optimizer
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

 logger = logging.getLogger(__file__)
@@ -317,12 +318,7 @@ def _build_model_optimizer(self):

         log_gpu_memory_usage("After FSDP wrapping", logger=logger)

-        self.optimizer = optim.AdamW(
-            self.fsdp_model.parameters(),
-            lr=self.config.optim.lr,
-            betas=self.config.optim.betas,
-            weight_decay=self.config.optim.weight_decay,
-        )
+        self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim)

         log_gpu_memory_usage("After initialize optimizer", logger=logger)

verl/workers/config/optimizer.py

Lines changed: 64 additions & 1 deletion
@@ -19,7 +19,7 @@

 from verl.base_config import BaseConfig

-__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig"]
+__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig", "build_optimizer"]


 @dataclass
@@ -58,15 +58,22 @@ class FSDPOptimizerConfig(OptimizerConfig):
     """FSDP optimizer configuration extending base OptimizerConfig.

     Args:
+        optimizer (str): Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW").
+        optimizer_impl (str): Module path to import optimizer from (e.g., "torch.optim", "torchao.optim",
+            "bitsandbytes.optim").
         lr (float): Learning rate.
         min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
         warmup_style (str): LR warmup style: "constant" or "cosine".
         num_cycles (float): Number of cosine cycles in LR schedule.
+        override_optimizer_config (Optional[dict]): Additional optimizer-specific keyword arguments.
     """

+    optimizer: str = "AdamW"
+    optimizer_impl: str = "torch.optim"
     min_lr_ratio: Optional[float] = None
     warmup_style: str = "constant"
     num_cycles: float = 0.5
+    override_optimizer_config: Optional[dict] = None

     def __post_init__(self):
         assert self.warmup_style in ["constant", "cosine"]
@@ -101,3 +108,59 @@ class McoreOptimizerConfig(OptimizerConfig):
     lr_wsd_decay_steps: Optional[int] = None
     use_checkpoint_opt_param_scheduler: bool = False
     override_optimizer_config: Optional[dict] = None
+
+
+def build_optimizer(parameters, config: FSDPOptimizerConfig):
+    """Build an optimizer based on the configuration.
+
+    Dynamically imports and instantiates an optimizer class from the specified module.
+
+    Args:
+        parameters: Model parameters to optimize
+        config: FSDPOptimizerConfig with optimizer settings
+
+    Returns:
+        Optimizer instance
+
+    Examples:
+        # PyTorch AdamW
+        config.optimizer_impl = "torch.optim"
+        config.optimizer = "AdamW"
+
+        # TorchAO AdamW with bf16 stochastic rounding
+        config.optimizer_impl = "torchao.optim"
+        config.optimizer = "_AdamW"
+        config.override_optimizer_config = {"bf16_stochastic_round": True}
+
+        # BitsAndBytes AdamW 8bit
+        config.optimizer_impl = "bitsandbytes.optim"
+        config.optimizer = "AdamW8bit"
+    """
+    import importlib
+
+    optimizer_args = {
+        "lr": config.lr,
+        "weight_decay": config.weight_decay,
+    }
+
+    optimizer_name_lower = config.optimizer.lower()
+    if "adam" in optimizer_name_lower or "ademamix" in optimizer_name_lower:
+        optimizer_args["betas"] = config.betas
+
+    if config.override_optimizer_config is not None:
+        optimizer_args.update(config.override_optimizer_config)
+
+    try:
+        module = importlib.import_module(config.optimizer_impl)
+        optimizer_cls = getattr(module, config.optimizer)
+    except ImportError as e:
+        raise ImportError(
+            f"Failed to import module '{config.optimizer_impl}'. Make sure the package is installed. Error: {e}"
+        ) from e
+    except AttributeError as e:
+        raise AttributeError(
+            f"Optimizer '{config.optimizer}' not found in module '{config.optimizer_impl}'. "
+            f"Available optimizers: {dir(module)}"
+        ) from e
+
+    return optimizer_cls(parameters, **optimizer_args)
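
A minimal usage sketch of the new helper (not part of the commit): it builds a plain torch.optim.AdamW for a toy module via FSDPOptimizerConfig, and assumes the base OptimizerConfig exposes the lr / weight_decay / betas fields that build_optimizer reads above.

    import torch.nn as nn

    from verl.workers.config.optimizer import FSDPOptimizerConfig, build_optimizer

    model = nn.Linear(16, 4)
    cfg = FSDPOptimizerConfig(
        lr=1e-5,
        optimizer="AdamW",             # class name resolved with getattr()
        optimizer_impl="torch.optim",  # module passed to importlib.import_module()
    )
    optimizer = build_optimizer(model.parameters(), cfg)
    print(type(optimizer).__name__)  # AdamW

Swapping in torchao.optim / _AdamW or bitsandbytes.optim / AdamW8bit only changes the two string fields, plus override_optimizer_config for any backend-specific keyword arguments.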

verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 3 additions & 7 deletions
@@ -354,14 +354,10 @@ def _build_fsdp_module(self, module):
         return module

     def _build_optimizer(self, module):
-        from torch import optim
+        from verl.workers.config.optimizer import build_optimizer
+
+        optimizer = build_optimizer(module.parameters(), self.optimizer_config)

-        optimizer = optim.AdamW(
-            module.parameters(),
-            lr=self.optimizer_config.lr,
-            betas=self.optimizer_config.betas,
-            weight_decay=self.optimizer_config.weight_decay,
-        )
         return optimizer

     def _build_lr_scheduler(self, optimizer):

verl/workers/fsdp_workers.py

Lines changed: 4 additions & 15 deletions
@@ -86,6 +86,7 @@
 from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
 from verl.utils.py_functional import convert_to_regular_types
 from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig
+from verl.workers.config.optimizer import build_optimizer
 from verl.workers.rollout import get_rollout_class
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -279,7 +280,6 @@ def _build_model_optimizer(
         role="actor",
         enable_activation_offload=False,
     ):
-        from torch import optim
         from torch.distributed.fsdp import CPUOffload, MixedPrecision
         from transformers import (
             AutoConfig,
@@ -520,12 +520,7 @@ def _build_model_optimizer(
         if role == "actor" and optim_config is not None:
             from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

-            actor_optimizer = optim.AdamW(
-                actor_module_fsdp.parameters(),
-                lr=optim_config.lr,
-                betas=optim_config.get("betas", (0.9, 0.999)),
-                weight_decay=optim_config.get("weight_decay", 1e-2),
-            )
+            actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config)

             total_steps = optim_config.get("total_training_steps", 0)
             num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
@@ -866,7 +861,7 @@ def update_actor(self, data: DataProto):
             metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

             lr = self.actor_lr_scheduler.get_last_lr()[0]
-            metrics["actor/lr"] = lr
+            metrics["actor/lr"] = lr.item() if torch.is_tensor(lr) else lr
             self.actor_lr_scheduler.step()

             # TODO: here, we should return all metrics
@@ -1187,7 +1182,6 @@ def __init__(self, config: FSDPCriticConfig):

     def _build_critic_model_optimizer(self, config):
         # the following line is necessary
-        from torch import optim
         from torch.distributed.fsdp import MixedPrecision

         from verl.utils.model import load_valuehead_model, print_model_size
@@ -1368,12 +1362,7 @@ def _build_critic_model_optimizer(self, config):

         log_gpu_memory_usage("After critic FSDP", logger=None)

-        critic_optimizer = optim.AdamW(
-            critic_module.parameters(),
-            lr=config.optim.lr,
-            betas=config.optim.get("betas", (0.9, 0.999)),
-            weight_decay=config.optim.get("weight_decay", 1e-2),
-        )
+        critic_optimizer = build_optimizer(critic_module.parameters(), config.optim)

         total_steps = config.optim.get("total_training_steps", 0)
         num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
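
The one-line change in update_actor above exists because get_last_lr() can hand back a 0-dim tensor rather than a Python float when the optimizer stores its learning rate as a tensor. A minimal sketch of the guard using only standard torch (the helper name is illustrative):

    import torch

    def to_loggable_lr(lr):
        # Convert a possibly 0-dim tensor learning rate into a plain float for logging.
        return lr.item() if torch.is_tensor(lr) else lr

    print(to_loggable_lr(1e-5))                # a float passes through unchanged
    print(to_loggable_lr(torch.tensor(1e-5)))  # a tensor is converted to a Python float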
