Skip to content

Commit 599eabe

Browse files
chenhaiq and vermouth1992
authored and committed
[fsdp] fix: set _set_allocator_settings to True to avoid fsdp2 oom (volcengine#3020)
### What does this PR do? Enable expandable_segments to avoid the increasing memory fragmentation caused by temporary variables during the training process of fsdp2, which may trigger probabilistic out-of-memory (OOM) errors. Since neither sglang nor vllm can work with expandable_segments:True, it has to be turned off during rollout. ### Test Without this fix, memory reserved could be very high after compute_log_prob or update_actor. ``` (WorkerDict pid=339320) [2025-08-11 17:43:01] dp actor After compute_log_prob, memory allocated (GB): 5.53, memory reserved (GB): 73.59, device memory used/total (GB): 77.47/79.15 ``` With this fix, it stays low during training. ``` (WorkerDict pid=396879) [2025-08-12 07:39:42] dp actor After compute_log_prob, memory allocated (GB): 4.95, memory reserved (GB): 14.20, device memory used/total (GB): 17.72/79.15 ``` --------- Co-authored-by: narutolhy <[email protected]> Co-authored-by: Chi Zhang <[email protected]>
1 parent ef49344 commit 599eabe

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

verl/utils/device.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,12 @@ def get_nccl_backend() -> str:
8484
return "hccl"
8585
else:
8686
raise RuntimeError(f"No available nccl backend found on device type {get_device_name()}.")
87+
88+
89+
def set_expandable_segments(enable: bool) -> None:
    """Enable or disable expandable segments for cuda.

    Expandable segments reduce allocator fragmentation (used here to avoid
    FSDP2 training OOMs); callers turn it off during rollout because the
    inference engines do not support it.

    Args:
        enable (bool): Whether to enable expandable segments. Used to avoid OOM.
    """
    # No-op on non-CUDA devices; the allocator knob only exists for CUDA.
    if not is_cuda_available:
        return
    # NOTE: private torch API — behaves like PYTORCH_CUDA_ALLOC_CONF at runtime.
    torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}")

verl/workers/sharding_manager/fsdp_sglang.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from verl import DataProto
2828
from verl.protocol import all_gather_data_proto
29-
from verl.utils.device import get_device_id, get_torch_device
29+
from verl.utils.device import get_device_id, get_torch_device, set_expandable_segments
3030
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
3131
from verl.utils.model import convert_weight_keys
3232
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
@@ -144,6 +144,10 @@ async def wake_up(self):
144144

145145
log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
146146

147+
# sglang needs _set_allocator_settings to be set to False
148+
logger.debug("fsdp sglang sharding_manager _set_allocator_settings to False")
149+
set_expandable_segments(False)
150+
147151
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
148152
if self.multi_stage_wake_up:
149153
await self.inference_engine.resume_memory_occupation(tags=["weights"])
@@ -185,6 +189,11 @@ async def sleep(self):
185189
# add empty cache after each compute
186190
get_torch_device().empty_cache()
187191

192+
# always set _set_allocator_settings to True when using sglang
193+
# it is required by fsdp2 to avoid oom
194+
logger.debug("fsdp sglang sharding_manager _set_allocator_settings to True")
195+
set_expandable_segments(True)
196+
188197
# restore random states
189198
if self.device_mesh is not None:
190199
self.gen_random_states = get_torch_device().get_rng_state()

verl/workers/sharding_manager/fsdp_vllm.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from verl.protocol import all_gather_data_proto
3535
from verl.third_party.vllm import LLM
3636
from verl.third_party.vllm import parallel_state as vllm_ps
37-
from verl.utils.device import get_device_id, get_device_name, get_torch_device
37+
from verl.utils.device import get_device_id, get_device_name, get_torch_device, set_expandable_segments
3838
from verl.utils.fsdp_utils import (
3939
fsdp_version,
4040
layered_summon_lora_params,
@@ -210,6 +210,10 @@ def __collect_lora_params() -> OrderedDict:
210210
offload_fsdp_model_to_cpu(self.module)
211211
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
212212

213+
# vllm needs _set_allocator_settings to be set to False
214+
logger.debug("fsdp vllm sharding_manager _set_allocator_settings to False")
215+
set_expandable_segments(False)
216+
213217
if self.rollout_config.free_cache_engine:
214218
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
215219
self.inference_engine.wake_up(tags=["weights"])
@@ -245,6 +249,10 @@ def __exit__(self, exc_type, exc_value, traceback):
245249
# add empty cache after each compute
246250
get_torch_device().empty_cache()
247251

252+
# _set_allocator_settings to True is required by fsdp2 to avoid oom
253+
logger.debug("fsdp vllm sharding_manager _set_allocator_settings to True")
254+
set_expandable_segments(True)
255+
248256
# restore random states
249257
if self.device_mesh is not None:
250258
self.gen_random_states = get_torch_device().get_rng_state()

0 commit comments

Comments
 (0)