Merged
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -441,6 +441,9 @@ actor_rollout_ref:
# number of responses (i.e. num sample times). > 1 for grpo
n: 1

# Whether to wake up the inference engine in multiple stages (wake up the model weights first, then resume the kv cache)
multi_stage_wake_up: false

# Extra inference engine arguments (vllm, sglang).
engine_kwargs:

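The new flag defaults to false. A minimal sketch of enabling it in a user-side config override, assuming the actor_rollout_ref.rollout path shown above:

actor_rollout_ref:
  rollout:
    # wake up the model weights first, then resume the kv cache after the weight sync
    multi_stage_wake_up: true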
1 change: 1 addition & 0 deletions verl/workers/fsdp_workers.py
@@ -478,6 +478,7 @@ def _build_rollout(self, trust_remote_code=False):
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)

16 changes: 11 additions & 5 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -127,21 +127,27 @@ def __init__(self, **kwargs):
# default to the dummy load format, which requires reloading weights the first time
self._need_reload = True

async def release_memory_occupation(self):
async def release_memory_occupation(self, tags: Optional[list[str]] = None):
"""Release GPU occupation temporarily."""
obj = ReleaseMemoryOccupationReqInput()
if tags is None:
obj = ReleaseMemoryOccupationReqInput()
else:
obj = ReleaseMemoryOccupationReqInput(tags=tags)
return await self.tokenizer_manager.release_memory_occupation(obj, None)

async def resume_memory_occupation(self):
async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
"""Resume GPU occupation."""

# because __init__ is a sync method, it can not call the async release_memory_occupation
# have to move release_memory_occupation from __init__ to here
# For multi-stage wake-up, release both weights and kv_cache the first time weights are resumed.
if self._need_reload:
await self.release_memory_occupation()
self._need_reload = False

obj = ResumeMemoryOccupationReqInput()
if tags is None:
obj = ResumeMemoryOccupationReqInput()
else:
obj = ResumeMemoryOccupationReqInput(tags=tags)
return await self.tokenizer_manager.resume_memory_occupation(obj, None)

async def update_weights_from_tensor(
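A usage sketch of the tagged API added above. Here engine stands for the engine wrapper patched in this file, and the tag names "weights" and "kv_cache" mirror the ones the sharding manager passes below; calling either method without tags keeps the previous all-at-once behavior.

async def staged_wake_up(engine):
    # Release GPU memory (weights and kv cache together) so training can use the device.
    await engine.release_memory_occupation()
    # ... training / weight update happens here ...
    # Stage 1: bring back only the model weights.
    await engine.resume_memory_occupation(tags=["weights"])
    # ... copy the updated weights into the engine here ...
    # Stage 2: re-allocate the kv cache once the weights are in place.
    await engine.resume_memory_occupation(tags=["kv_cache"])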
20 changes: 16 additions & 4 deletions verl/workers/sharding_manager/fsdp_sglang.py
@@ -59,12 +59,14 @@ def __init__(
full_params: bool = False,
device_mesh: DeviceMesh = None,
offload_param: bool = False,
multi_stage_wake_up: bool = False,
):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.device_mesh = device_mesh
self.offload_param = offload_param
self.multi_stage_wake_up = multi_stage_wake_up

# Full params
self.full_params = full_params
@@ -95,7 +97,17 @@ def __init__(
def __enter__(self):
self.timing = {}
with simple_timer("reshard", self.timing):
loop = asyncio.get_event_loop()

if self.device_mesh["infer_tp"].get_local_rank() == 0:
if self.multi_stage_wake_up:
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
else:
loop.run_until_complete(self.inference_engine.resume_memory_occupation())
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
get_torch_device().empty_cache()

log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
load_fsdp_model_to_gpu(self.module)
@@ -105,7 +117,6 @@ def __enter__(self):
params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
# Copy, not share memory
loop = asyncio.get_event_loop()
loop.run_until_complete(self.update_weights(params))
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

@@ -115,6 +126,10 @@ def __enter__(self):
get_torch_device().empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

if self.multi_stage_wake_up:
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)

# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
@@ -138,9 +153,6 @@ def __exit__(self, exc_type, exc_value, traceback):
get_torch_device().set_rng_state(self.torch_random_states)

async def update_weights(self, params):
if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self.inference_engine.resume_memory_occupation()

# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
named_tensors = [(k, v) for k, v in params.items()]
load_format = None
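Taken together, the sharding-manager changes above make __enter__ resume memory in two stages when multi_stage_wake_up is enabled (the resume calls run only on infer_tp rank 0; that gating is omitted here). A simplified sketch of the ordering, where update_weights stands for the manager's weight-sync coroutine:

async def multi_stage_enter(engine, params, update_weights):
    # Stage 1: restore only the model weights, so the weight copy does not
    # have to share GPU memory with the kv cache.
    await engine.resume_memory_occupation(tags=["weights"])
    await update_weights(params)  # copy the FSDP params into the inference engine
    del params                    # drop the local reference, mirroring the del in __enter__ above
    # Stage 2: re-allocate the kv cache only after the weights are synced.
    await engine.resume_memory_occupation(tags=["kv_cache"])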