Skip to content

Commit 9910f16

Browse files
committed
fix
1 parent 4c2ff9f commit 9910f16

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,25 @@ def __init__(self, **kwargs):
129129

130130
async def release_memory_occupation(self, tags: Optional[list[str]] = None):
131131
"""Release GPU occupation temporarily."""
132-
print(f"release_memory_occupation with tags: {tags}")
133-
obj = ReleaseMemoryOccupationReqInput(tags=tags)
132+
if tags is None:
133+
obj = ReleaseMemoryOccupationReqInput()
134+
else:
135+
obj = ReleaseMemoryOccupationReqInput(tags=tags)
134136
return await self.tokenizer_manager.release_memory_occupation(obj, None)
135137

136138
async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
137139
"""Resume GPU occupation."""
138-
print(f"resume_memory_occupation with tags: {tags}")
139140
# because __init__ is a sync method, it can not call the async release_memory_occupation
140141
# have to move release_memory_occupation from __init__ to here
141142
# For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time.
142143
if self._need_reload:
143144
await self.release_memory_occupation()
144145
self._need_reload = False
145146

146-
obj = ResumeMemoryOccupationReqInput(tags=tags)
147+
if tags is None:
148+
obj = ResumeMemoryOccupationReqInput()
149+
else:
150+
obj = ResumeMemoryOccupationReqInput(tags=tags)
147151
return await self.tokenizer_manager.resume_memory_occupation(obj, None)
148152

149153
async def update_weights_from_tensor(

0 commit comments

Comments
 (0)