
Commit 391b912

move FSDP param load/offload into sharding manager
1 parent da4222b commit 391b912

File tree: 8 files changed, +88 / -85 lines
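This commit moves the FSDP parameter load/offload that the worker used to do around generation into the rollout sharding manager itself: the manager's __enter__ loads params to GPU, syncs weights into the inference engine, and offloads them back to CPU. A minimal, self-contained sketch of that pattern (toy class and helpers assumed here, not verl's actual API):

import torch
import torch.nn as nn


# Toy sketch of the new ownership (assumed names, not verl's API): the sharding manager,
# not the worker, loads FSDP params to GPU for weight sync and offloads them afterwards.
class ToyRolloutShardingManager:
    def __init__(self, module: nn.Module, offload_param: bool = False):
        self.module = module
        self.offload_param = offload_param

    def __enter__(self):
        # bring params to GPU only for the weight-sync step
        if self.offload_param and torch.cuda.is_available():
            self.module.cuda()
        params = self.module.state_dict()  # stand-in for syncing weights into the rollout engine
        del params
        # params are no longer needed on GPU once the rollout engine has its copy
        if self.offload_param:
            self.module.cpu()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# usage: the worker just enters the context; no explicit load/offload calls remain in its code
with ToyRolloutShardingManager(nn.Linear(8, 8), offload_param=True):
    pass  # generation would run here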

recipe/dapo/src/config/dapo_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ actor_rollout_ref:
     ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
   rollout:
     name: vllm
+    mode: "sync" # sync: LLM, async: AsyncLLM
     temperature: 1.0
     top_k: -1 # 0 for hf rollout, -1 for vllm rollout
     top_p: 1
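The new rollout.mode key selects between the synchronous LLM engine and AsyncLLM; the same key is added to generation.yaml and ppo_megatron_trainer.yaml below. It can also be overridden in code, as the updated test does; a small illustration (the config path is assumed to match this repo layout):

from omegaconf import OmegaConf

# Illustration only: flip the new rollout.mode key at runtime instead of editing the YAML.
config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
config.actor_rollout_ref.rollout.mode = "async"  # "sync" -> LLM, "async" -> AsyncLLM
print(OmegaConf.to_yaml(config.actor_rollout_ref.rollout))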

tests/rollout/test_vllm_multi_turn.py

Lines changed: 30 additions & 31 deletions
@@ -20,35 +20,40 @@
 from openai.types.chat.chat_completion import ChatCompletion
 
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
-from verl.single_controller.ray.base import Worker, create_colocated_worker_cls
+from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
 from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager
 from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler
 
 
 async def test_vllm_multi_turn():
     config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
-    model_path = "Qwen/Qwen2-7B-Instruct"
+    model_path = "/mnt/bn/wuxibin-hl-dev/cache/Qwen/Qwen2-7B-Instruct"
     model_name = "/".join(model_path.split("/")[-2:])
     config.actor_rollout_ref.model.path = model_path
     config.actor_rollout_ref.rollout.mode = "async"
     config.actor_rollout_ref.rollout.prompt_length = 4096
     config.actor_rollout_ref.rollout.response_length = 4096
 
+    # test sleep/wake_up with fsdp offload
+    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
+    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
+
     # =========================== 1. Create hybrid ActorRollout workers ===========================
     ray.init(
         runtime_env={
-            'env_vars': {
-                'TOKENIZERS_PARALLELISM': 'true',
-                'NCCL_DEBUG': 'WARN',
-                'VLLM_LOGGING_LEVEL': 'WARN',
-                'VLLM_USE_V1': '1',
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "WARN",
+                "VLLM_USE_V1": "1",
             }
-        })
+        }
+    )
     role_worker_mapping = {
         Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
     }
-    global_pool_id = 'global_pool'
+    global_pool_id = "global_pool"
     resource_pool_spec = {
         global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
     }
@@ -61,20 +66,20 @@ async def test_vllm_multi_turn():
 
     # create actor and rollout
     resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
-    actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout],
-                                             config=config.actor_rollout_ref,
-                                             role='actor_rollout')
-    resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
+    actor_rollout_cls = RayClassWithInitArgs(
+        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout"
+    )
+    resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
 
     all_wg = {}
     wg_dicts = []
     for resource_pool, class_dict in resource_pool_to_cls.items():
-        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker)
+        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
         wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
         spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
         all_wg.update(spawn_wg)
         wg_dicts.append(wg_dict)
-    actor_rollout_wg = all_wg['actor_rollout']
+    actor_rollout_wg = all_wg["actor_rollout"]
     actor_rollout_wg.init_model()
 
     # =========================== 2. Create AsyncLLMManager&ChatScheduler ===========================
@@ -89,6 +94,10 @@ async def test_vllm_multi_turn():
         server_addresses=async_rollout_manager.server_addresses,
     )
 
+    # test sleep and wake_up
+    async_rollout_manager.sleep()
+    async_rollout_manager.wake_up()
+
     # =========================== 3. Multi turn rollout ===========================
     async def callback(completions: ChatCompletion, info: Dict[str, Any]):
         messages, round = info["messages"], info["round"]
@@ -101,10 +110,7 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]):
             messages.append({"role": "user", "content": "What is your name?"})
             await async_chat_scheduler.submit_chat_completions(
                 callback=callback,
-                callback_additional_info={
-                    "messages": messages,
-                    "round": 1
-                },
+                callback_additional_info={"messages": messages, "round": 1},
                 model=model_name,
                 messages=messages,
                 extra_headers=extra_headers,
@@ -113,27 +119,20 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]):
             messages.append({"role": "user", "content": "What is your favorite color?"})
             await async_chat_scheduler.submit_chat_completions(
                 callback=callback,
-                callback_additional_info={
-                    "messages": messages,
-                    "round": 2
-                },
+                callback_additional_info={"messages": messages, "round": 2},
                 model=model_name,
                 messages=messages,
                 extra_headers=extra_headers,
             )
         else:
             print("Done!")
 
-    messages = [{
-        "role": "user",
-        "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."
-    }]
+    messages = [
+        {"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}
+    ]
     await async_chat_scheduler.submit_chat_completions(
         callback=callback,
-        callback_additional_info={
-            "messages": messages,
-            "round": 0
-        },
+        callback_additional_info={"messages": messages, "round": 0},
         model=model_name,
         messages=messages,
     )
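The multi-turn test drives rollout through callback chaining: each completion's callback decides whether to submit the next turn. A stripped-down, runnable sketch of that chaining style with toy stand-ins (not the verl ChatCompletionScheduler API):

import asyncio
from typing import Any, Dict


async def submit(messages, callback, info):
    # toy "scheduler": immediately produce a canned assistant reply and invoke the callback
    reply = {"role": "assistant", "content": f"echo: {messages[-1]['content']}"}
    await callback(reply, info)


async def callback(reply: Dict[str, Any], info: Dict[str, Any]):
    messages, round_ = info["messages"], info["round"]
    messages.append(reply)
    if round_ == 0:
        # first turn finished: queue the follow-up question, bumping the round counter
        messages.append({"role": "user", "content": "What is your name?"})
        await submit(messages, callback, {"messages": messages, "round": 1})
    else:
        print(f"Done! {len(messages)} messages")


async def main():
    messages = [{"role": "user", "content": "Let's play a role playing game."}]
    await submit(messages, callback, {"messages": messages, "round": 0})


asyncio.run(main())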

verl/single_controller/ray/base.py

Lines changed: 22 additions & 23 deletions
@@ -329,10 +329,9 @@ def from_detached(
         worker_names=None,
         ray_cls_with_init=None,
     ):
-        worker_group = cls(resource_pool=None,
-                           ray_cls_with_init=ray_cls_with_init,
-                           name_prefix=name_prefix,
-                           worker_names=worker_names)
+        worker_group = cls(
+            resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names
+        )
         return worker_group
 
     def spawn(self, prefix_set):
@@ -457,12 +456,13 @@ def func(self, *args, **kwargs):
        try:
            # bind direct rollout method to class without prefix
            if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key:
-                assert not hasattr(cls, method_name), \
+                assert not hasattr(cls, method_name), (
                    f"conflict direct rollout method {method_name} with role {key}"
+                )
                setattr(cls, method_name, func)
                print(f"bind role {key} method {method_name} to class {cls}")
            else:
-                method_name_with_prefix = key + '_' + method_name
+                method_name_with_prefix = key + "_" + method_name
                setattr(cls, method_name_with_prefix, func)
        except Exception as e:
            raise ValueError(f"Fail to set method_name {method_name}")
@@ -474,32 +474,31 @@ def _unwrap_ray_remote(cls):
     return cls
 
 
-def _nearest_common_base(mros: List):
-    last_common = object
-    min_len = min([len(mro) for mro in mros]) - 1  # exclude final derived class
-
-    for i in range(min_len):
-        mro = mros[0][i]
-        for j in range(1, len(mros)):
-            if mro != mros[j][i]:
-                return last_common
-        last_common = mro
-
-    return last_common
+def _determine_fsdp_megatron_base_class(mros: List):
+    """
+    - megatron: base class should be MegatronWorker
+    - fsdp: base class should be Worker
+    """
+    for cls in mros[0]:
+        if cls.__name__ == "MegatronWorker":
+            return cls
+        if cls.__name__ == "Worker":
+            return cls
+    raise ValueError(f"Cannot determine base class for {mros}")
 
 
-def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None):
+def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
     """
     This function should return a class instance that delegates the calls to every
     cls in cls_dict
     """
     cls_dict = {}
     init_args_dict = {}
-    if worker_cls is None:
-        worker_cls = _nearest_common_base(
-            [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()])
+    worker_cls = _determine_fsdp_megatron_base_class(
+        [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]
+    )
     assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker"
-    print(f"find nearest common base class {worker_cls}")
+    print(f"colocated worker base class {worker_cls}")
 
     for key, cls in class_dict.items():
         cls_dict[key] = cls.cls
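The replaced nearest-common-base search is now a name-based scan of the first worker's MRO. A toy illustration of why that returns MegatronWorker for Megatron stacks and Worker for FSDP stacks (stand-in classes, not verl's):

from typing import List


# Stand-in classes for illustration; the real Worker/MegatronWorker live in verl.single_controller.
class Worker: ...
class MegatronWorker(Worker): ...
class MyMegatronActor(MegatronWorker): ...
class MyFSDPActor(Worker): ...


def pick_base(mro: List[type]) -> type:
    # mirrors _determine_fsdp_megatron_base_class: prefer MegatronWorker, else Worker
    for cls in mro:
        if cls.__name__ == "MegatronWorker":
            return cls
        if cls.__name__ == "Worker":
            return cls
    raise ValueError("Cannot determine base class")


print(pick_base(list(MyMegatronActor.__mro__)))  # -> MegatronWorker
print(pick_base(list(MyFSDPActor.__mro__)))      # -> Worker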

verl/trainer/config/generation.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ model:
   external_lib: null
 rollout:
   name: vllm
+  mode: "sync" # sync: LLM, async: AsyncLLM
   temperature: 1.0
   top_k: 50 # 0 for hf rollout, -1 for vllm rollout
   top_p: 0.7

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ actor_rollout_ref:
     log_prob_micro_batch_size_per_gpu: null
   rollout:
     name: vllm
+    mode: "sync" # sync: LLM, async: AsyncLLM
     temperature: 1.0
     top_k: -1 # 0 for hf rollout, -1 for vllm rollout
     top_p: 1

verl/workers/fsdp_workers.py

Lines changed: 16 additions & 18 deletions
@@ -353,18 +353,22 @@ def _build_rollout(self, trust_remote_code=False):
             log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
             local_path = copy_to_local(self.config.model.path)
             if vllm_mode == "customized":
-                rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
-                                      config=self.config.rollout,
-                                      tokenizer=self.tokenizer,
-                                      model_hf_config=self.actor_model_config)
+                rollout = vLLMRollout(
+                    actor_module=self.actor_module_fsdp,
+                    config=self.config.rollout,
+                    tokenizer=self.tokenizer,
+                    model_hf_config=self.actor_model_config,
+                )
             elif vllm_mode == "spmd":
                 vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
-                rollout = vllm_rollout_cls(model_path=local_path,
-                                           config=self.config.rollout,
-                                           tokenizer=self.tokenizer,
-                                           model_hf_config=self.actor_model_config,
-                                           device_mesh=rollout_device_mesh,
-                                           trust_remote_code=trust_remote_code)
+                rollout = vllm_rollout_cls(
+                    model_path=local_path,
+                    config=self.config.rollout,
+                    tokenizer=self.tokenizer,
+                    model_hf_config=self.actor_model_config,
+                    device_mesh=rollout_device_mesh,
+                    trust_remote_code=trust_remote_code,
+                )
             else:
                 raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
             log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
@@ -376,6 +380,7 @@ def _build_rollout(self, trust_remote_code=False):
                 model_config=self.actor_model_config,
                 full_params="hf" in self.config.rollout.load_format,
                 device_mesh=rollout_device_mesh,
+                offload_param=self._is_offload_param,
             )
             log_gpu_memory_usage("After building sharding manager", logger=None)
 
@@ -407,6 +412,7 @@ def _build_rollout(self, trust_remote_code=False):
                 model_config=self.actor_model_config,
                 full_params="hf" in self.config.rollout.load_format,
                 device_mesh=rollout_device_mesh,
+                offload_param=self._is_offload_param,
             )
             log_gpu_memory_usage("After building sharding manager", logger=None)
 
@@ -546,8 +552,6 @@ def generate_sequences(self, prompts: DataProto):
         prompts = prompts.to(torch.cuda.current_device())
 
         assert self._is_rollout
-        if self._is_offload_param:
-            load_fsdp_model_to_gpu(self.actor_module_fsdp)
 
         meta_info = {
             "eos_token_id": self.generation_config.eos_token_id
@@ -559,12 +563,6 @@ def generate_sequences(self, prompts: DataProto):
         }
         prompts.meta_info.update(meta_info)
         with self.rollout_sharding_manager:
-            # after parameters sync with rollout, offload actor model to CPU
-            if self._is_offload_param:
-                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
-            if self._is_offload_optimizer:
-                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
-
             log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
 
             prompts = self.rollout_sharding_manager.preprocess_data(prompts)
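With these deletions, generate_sequences no longer calls load_fsdp_model_to_gpu or offload_fsdp_model_to_cpu itself; entering self.rollout_sharding_manager is what triggers the load / sync / offload sequence (the analogous change to the sglang sharding manager is shown below). The dropped offload_fsdp_optimizer call suggests the optimizer is expected to already be off the GPU, or handled elsewhere, by the time rollout runs; that is an inference from this diff rather than something stated in the commit.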

verl/workers/sharding_manager/fsdp_sglang.py

Lines changed: 7 additions & 6 deletions
@@ -37,6 +37,7 @@
 from verl import DataProto
 from verl.protocol import all_gather_data_proto
 from verl.utils.debug import log_gpu_memory_usage
+from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
 from verl.utils.torch_functional import broadcast_dict_tensor
 
 from .base import BaseShardingManager
@@ -55,11 +56,13 @@ def __init__(
         model_config,
         full_params: bool = False,
         device_mesh: DeviceMesh = None,
+        offload_param: bool = False,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.model_config = model_config
         self.device_mesh = device_mesh
+        self.offload_param = offload_param
 
         # Full params
         self.full_params = full_params
@@ -88,6 +91,8 @@ def __init__(
     def __enter__(self):
         torch.cuda.empty_cache()
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
+        if self.offload_param:
+            load_fsdp_model_to_gpu(self.module)
         params = self.module.state_dict()
         log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
         # Copy, not share memory
@@ -98,15 +103,11 @@ def __enter__(self):
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
 
         del params
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
         torch.cuda.empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
-        # TODO: offload FSDP model weights
-        # self.module.cpu()
-        # torch.cuda.empty_cache()
-        # if torch.distributed.get_rank() == 0:
-        #     print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
-
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
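Note the ordering in __enter__: parameters are loaded to the GPU only long enough to materialize the state_dict and sync weights into the inference engine, then offloaded again before empty_cache(), so the freed memory is available before generation starts. With offload_param handling this explicitly, the old commented-out TODO about offloading FSDP weights is no longer needed and is removed.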
