[rollout] feat: support over sampling rollout in SGLang Rollout #2929
@@ -118,6 +118,26 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     prompt_length = response_info["prompt_length"]
     response_length = response_info["response_length"]

+    # Detect aborted requests: requests whose response_mask is all zeros.
+    # An aborted request has no valid tokens in its response.
+    # Use the response_mask from the batch to stay consistent with agg_loss.
+    aborted_mask = (response_length == 0).bool()  # response_length == 0 means the request was aborted
+
+    non_aborted_mask = ~aborted_mask
+
+    print("over sample rate in metric_utils: ", non_aborted_mask.sum() / len(non_aborted_mask))
+
+    non_aborted_sequence_score = sequence_score[non_aborted_mask]
+    non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
+
+    score_mean = torch.mean(non_aborted_sequence_score).detach().item()
+    score_max = torch.max(non_aborted_sequence_score).detach().item()
+    score_min = torch.min(non_aborted_sequence_score).detach().item()
+
+    reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
+    reward_max = torch.max(non_aborted_sequence_reward).detach().item()
+    reward_min = torch.min(non_aborted_sequence_reward).detach().item()
+
     valid_adv = torch.masked_select(advantages, response_mask)
     valid_returns = torch.masked_select(returns, response_mask)
@@ -128,14 +148,14 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     return_var = torch.var(valid_returns)

     metrics = {
-        # score
-        "critic/score/mean": torch.mean(sequence_score).detach().item(),
-        "critic/score/max": torch.max(sequence_score).detach().item(),
-        "critic/score/min": torch.min(sequence_score).detach().item(),
-        # reward
-        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
-        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
-        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
+        # score (mean/max/min computed over non-aborted requests only)
+        "critic/score/mean": score_mean,
+        "critic/score/max": score_max,
+        "critic/score/min": score_min,
+        # reward (mean/max/min computed over non-aborted requests only)
+        "critic/rewards/mean": reward_mean,
+        "critic/rewards/max": reward_max,
+        "critic/rewards/min": reward_min,
         # adv
         "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
         "critic/advantages/max": torch.max(valid_adv).detach().item(),

Review comment: Does it make sense to apply the same non_aborted_mask logic to the response_length metrics as well (mean/max/min response length)? Otherwise, the metrics might still include padded responses from aborted requests.
Reply: Makes sense.
Reply: I changed it accordingly, adding the response-length mean before/after abort, and I added a drop-rate metric.
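The resolution of the thread above (masking the response-length metrics and reporting a drop rate) could look roughly like the sketch below. This is only an illustration of the idea, not the code that actually landed, and the metric keys here are made up; it assumes the tensors defined earlier in compute_data_metrics are in scope and that the metrics dict has already been built.

    # Hypothetical follow-up inside compute_data_metrics
    non_aborted_response_length = response_length[non_aborted_mask]

    metrics.update(
        {
            # fraction of requests dropped (aborted) by over sampling
            "rollout/drop_rate": aborted_mask.float().mean().detach().item(),
            # response length over all requests vs. over the kept (non-aborted) requests only
            "response_length/mean_all": torch.mean(response_length.float()).detach().item(),
            "response_length/mean_non_aborted": torch.mean(non_aborted_response_length.float()).detach().item(),
        }
    )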
@@ -83,6 +83,8 @@
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

+OVER_SAMPLE_RATE = float(os.getenv("OVER_SAMPLE_RATE", 1))
+

 # patch to avoid issue https://github.com/sgl-project/sglang/issues/6723
 def _set_envs_and_config(server_args: ServerArgs):
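Since the knob is read from the environment at import time, it presumably has to be set before the rollout workers are spawned. A hedged example of doing so from the launcher process follows; the 0.8 is just an illustrative value, not a recommended default.

    import os

    # Keep roughly 80% of the rollout requests per batch and abort the stragglers.
    # Must be visible in the environment of every rollout worker before this module is imported.
    os.environ.setdefault("OVER_SAMPLE_RATE", "0.8")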
@@ -170,6 +172,21 @@ async def update_weights_from_tensor(self, update_weights_request: UpdateWeights
     async def flush_cache(self):
         return await self.tokenizer_manager.flush_cache()

+    async def abort_request(self, rid: str = "", abort_all: bool = False):
+        """Abort a specific request or all requests.
+
+        Args:
+            rid: The request ID to abort. If empty and abort_all is False, no action is taken.
+            abort_all: If True, abort all running requests regardless of rid.
+        """
+        try:
+            result = self.tokenizer_manager.abort_request(rid=rid, abort_all=abort_all)
+            print(f"🔍 Abort result: {result}")
+            return result if result is not None else {"status": "aborted"}
+        except Exception as e:
+            logger.error(f"Failed to abort requests: {e}")
+            raise
+
     # NOTE(sgm): add for verl. We can optimize it by making
     # the dataloader yield List[int] without padding.
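For orientation, a minimal sketch of how this hook can be driven from the rollout side; `engine` stands for the patched async engine wrapper above, and the helper itself is hypothetical. Later in this diff the rollout loop uses exactly the abort_all=True form.

    async def cancel_stragglers(engine, straggler_ids: list[str]) -> None:
        # Abort a known set of slow requests individually...
        for rid in straggler_ids:
            await engine.abort_request(rid=rid)
        # ...or abort everything still in flight in one call.
        await engine.abort_request(abort_all=True)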
@@ -872,7 +889,9 @@ async def _async_rollout_a_request(
             # Only continue the conversation if the prompt length is not greater than max_model_len - 1,
             # since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra
             # token accounts for the EOS token).
-            if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len:
+            prompt_length = len(_req.get_generation_prompt_ids(self.processing_class))
+
+            if prompt_length + 1 >= self.config.max_model_len:
                 finish_reason_type = FinishReasonTypeEnum.LENGTH
                 break

@@ -1024,9 +1043,11 @@ async def _handle_engine_generate(
         self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None
     ) -> dict:
         max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)

         kwargs = sampling_params.copy()
         kwargs["max_new_tokens"] = max_new_tokens
+        kwargs["n"] = 1  # group size is supported in preprocess
+
         output = await self._engine.async_generate(
             input_ids=generation_prompt_ids,
             sampling_params=kwargs,
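Forcing n=1 here relies on the group size already being expanded during preprocessing, i.e. each prompt is duplicated before generation rather than sampled n times by the engine. Conceptually that expansion looks like the simplified sketch below; this is not the actual _preprocess_prompt_to_async_rollout_requests, and the dict-based prompt shape is an assumption for illustration.

    def expand_prompt_group(prompts: list[dict], n: int) -> list[dict]:
        # Duplicate every prompt n times; rollout_offset distinguishes the copies
        # so results can later be re-sorted by (batch_data_id, rollout_offset).
        return [
            {**prompt, "rollout_offset": offset}
            for prompt in prompts
            for offset in range(n)
        ]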
@@ -1059,16 +1080,6 @@ async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRollout
             interaction = self.interaction_map[interaction_name]
             await interaction.start_interaction(_req.request_id, **interaction_kwargs)

-    @GPUMemoryLogger(role="sglang rollout", logger=logger)
-    @torch.no_grad()
-    def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
-        logger.warning(
-            "`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self._req_level_generate_sequences(prompts, **kwargs)
-
     @GPUMemoryLogger(role="sglang rollout", logger=logger)
     @torch.no_grad()
     def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
@@ -1082,16 +1093,125 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
         do_sample = prompts.meta_info.get("do_sample", True)
         is_validate = prompts.meta_info.get("validate", False)
         tgt_device = prompts.batch["input_ids"].device

         if self._tp_rank == 0:
             req_list = self._preprocess_prompt_to_async_rollout_requests(
                 prompts,
             )
-            loop = asyncio.get_event_loop()
-            output_req_list = loop.run_until_complete(
-                asyncio.gather(
-                    *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
-                )
-            )
+
+            # Add progress monitoring and abort support
+            total_requests = len(req_list)
+            target_completion = int(total_requests * OVER_SAMPLE_RATE)  # abort once the target fraction (e.g. 80%) has completed
+
+            print(f"🎯 Training mode: over sampling target {target_completion}/{total_requests}")
+            completed_count = 0
+            aborted_requests = []
+
+            # Separate the training and validation phases
+            if is_validate:
+                print(f"🔍 Validation mode: processing all {total_requests} requests without abort")
+
+                # Validation phase: process all requests, never abort
+                async def process_all_requests():
+                    return await asyncio.gather(
+                        *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
+                    )
+
+                loop = asyncio.get_event_loop()
+                output_req_list = loop.run_until_complete(process_all_requests())
+            else:
+                print(f"🎯 Training mode: over sampling target {target_completion}/{total_requests}")
+
+                completion_lock = asyncio.Lock()
+                all_tasks = []
+
+                async def process_request_with_monitoring(req):
+                    nonlocal completed_count
+                    try:
+                        result = await self._async_rollout_a_request(req, do_sample, is_validate, **kwargs)
+
+                        async with completion_lock:
+                            if completed_count < target_completion:
+                                completed_count += 1
+                                print(f"✅ Request {req.request_id} completed ({completed_count}/{total_requests})")
+                                return result
+                            else:
+                                # The request finished, but the target is already met; return padding instead
+                                logger.info(f"Request {req.request_id} finished after target met, creating padding")
+                                return self._create_padding_request(req)
+                    except asyncio.CancelledError:
+                        # The request was cancelled; return padding
+                        logger.info(f"Request {req.request_id} was cancelled, creating padding")
+                        aborted_requests.append(req.request_id)
+                        return self._create_padding_request(req)
+                    except Exception as e:
+                        # The request failed; count it as completed as well
+                        logger.warning(f"Request {req.request_id} failed: {e}")
+                        aborted_requests.append(req.request_id)
+                        async with completion_lock:
+                            if completed_count < target_completion:
+                                completed_count += 1
+                        return self._create_padding_request(req)
+
+                async def monitor_and_cancel():
+                    nonlocal completed_count
+                    while completed_count < target_completion:
+                        await asyncio.sleep(0.1)
+
+                    print(f"🎯 Target reached: {completed_count}/{total_requests} completed!")
+                    print("🚫 Cancelling remaining requests and sending abort to engine...")
+
+                    # Cancel the remaining tasks
+                    cancelled_count = 0
+                    for task in all_tasks:
+                        if not task.done():
+                            task.cancel()
+                            cancelled_count += 1
+
+                    print(f"📋 Cancelled {cancelled_count} remaining tasks")
+
+                    # Send an abort signal to the engine to interrupt all in-flight requests
+                    try:
+                        abort_result = await self._engine.abort_request(abort_all=True)
+                        print(f"✅ Abort signal sent to engine: {abort_result}")
+                    except Exception as e:
+                        print(f"❌ Failed to send abort signal to engine: {e}")

+                async def run_with_cancellation():
+                    nonlocal all_tasks
+
+                    # Create all tasks
+                    all_tasks = [asyncio.create_task(process_request_with_monitoring(req)) for req in req_list]
+
+                    # Start the monitoring task
+                    monitor_task = asyncio.create_task(monitor_and_cancel())
+
+                    try:
+                        # Wait for all tasks to finish (including the cancelled ones)
+                        results = await asyncio.gather(*all_tasks, return_exceptions=True)
+
+                        # Process the results, converting exceptions into padding
+                        output_req_list = []
+                        for i, result in enumerate(results):
+                            if isinstance(result, Exception):
+                                # Exceptions (including CancelledError) become padding
+                                logger.warning(f"Task {i} resulted in exception: {result}")
+                                output_req_list.append(self._create_padding_request(req_list[i]))
+                            else:
+                                output_req_list.append(result)
+
+                        return output_req_list
+                    finally:
+                        # Cancel the monitoring task
+                        monitor_task.cancel()
+                        try:
+                            await monitor_task
+                        except asyncio.CancelledError:
+                            pass
+
+                # Run the async tasks
+                loop = asyncio.get_event_loop()
+                output_req_list = loop.run_until_complete(run_with_cancellation())

             sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
         else:
             sorted_output_req_list = None

Review comment: Personally, I think encapsulating this part of the code for the training mode would be better; currently the whole function is too long.
Reply: We can do this, but the current design looks good to me. 😂 Just wrapping the three functions process_request_with_monitoring, monitor_and_cancel, and run_with_cancellation should be fine.
Reply: Sure, it varies from person to person.
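A rough sketch of the refactor suggested in the thread above: pull the training-mode machinery into a helper so that _req_level_generate_sequences stays short. The method names below are hypothetical, and the three inner coroutines are assumed to be moved into the helper along with it.

    def _generate_with_over_sampling(self, req_list, do_sample, is_validate, **kwargs):
        """Hypothetical wrapper around the monitoring/cancellation logic shown above."""
        loop = asyncio.get_event_loop()
        if is_validate:
            # Validation keeps every request; no abort.
            coros = [self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list]
            return loop.run_until_complete(asyncio.gather(*coros))
        # Training applies the over-sampling target and cancels the stragglers.
        return loop.run_until_complete(self._run_with_cancellation(req_list, do_sample, is_validate, **kwargs))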
@@ -1279,6 +1399,94 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
             non_tensor_batch=non_tensor_batch,
         )

+    def _create_padding_request(self, original_req: AsyncRolloutRequest) -> AsyncRolloutRequest:
+        """Create a padding request that stands in for an aborted request.
+
+        The padding request:
+        1. is in the COMPLETED state but carries an empty response;
+        2. has a response_loss_mask of all zeros, so it is ignored in the loss computation;
+        3. keeps the structure of the original request while its content stays empty.
+        """
+        # Padding response_ids (all pad_token_id)
+        padding_response_length = self.config.response_length
+        padding_response_ids = torch.full(
+            (1, padding_response_length),
+            self.pad_token_id,
+            dtype=torch.long,
+            device=original_req.input_ids.device if original_req.input_ids is not None else "cpu",
+        )
+
+        # Padding attention_mask (all zeros)
+        padding_response_attention_mask = torch.zeros(
+            (1, padding_response_length),
+            dtype=torch.long,
+            device=original_req.attention_mask.device if original_req.attention_mask is not None else "cpu",
+        )
+
+        # Padding position_ids
+        if original_req.position_ids is not None:
+            prompt_length = original_req.prompt_ids.shape[-1] if original_req.prompt_ids is not None else 0
+            padding_response_position_ids = torch.arange(
+                prompt_length, prompt_length + padding_response_length, dtype=torch.long
+            ).unsqueeze(0)
+            if original_req.position_ids.dim() == 2:
+                # 2D position_ids (e.g. qwen2vl)
+                padding_response_position_ids = padding_response_position_ids.repeat(
+                    original_req.position_ids.shape[0], 1
+                )
+        else:
+            padding_response_position_ids = None
+
+        # Padding loss_mask (all zeros, so the response is ignored)
+        padding_response_loss_mask = torch.zeros(
+            (1, padding_response_length),
+            dtype=torch.long,
+            device=original_req.loss_mask.device if original_req.loss_mask is not None else "cpu",
+        )
+
+        # Build a new request that keeps the original structure but uses padding data
+        padding_req = AsyncRolloutRequest(
+            batch_data_id=original_req.batch_data_id,
+            rollout_offset=original_req.rollout_offset,
+            request_id=original_req.request_id + "_padding",
+            state=AsyncRolloutRequestStateEnum.COMPLETED,
+            messages=original_req.messages,  # keep the original messages
+            multi_modal_keys=original_req.multi_modal_keys,
+            multi_modal_data=original_req.multi_modal_data,
+            multi_modal_inputs=original_req.multi_modal_inputs,
+            tool_schemas=original_req.tool_schemas,
+            tools_kwargs=original_req.tools_kwargs,
+            interaction_kwargs=original_req.interaction_kwargs,
+            input_ids=original_req.input_ids,  # keep the original input_ids
+            prompt_ids=original_req.prompt_ids,  # keep the original prompt_ids
+            response_ids=padding_response_ids,  # padded response_ids
+            attention_mask=original_req.attention_mask,  # keep the original attention_mask
+            prompt_attention_mask=original_req.prompt_attention_mask,  # keep the original prompt_attention_mask
+            response_attention_mask=padding_response_attention_mask,  # all-zero response_attention_mask
+            position_ids=original_req.position_ids,  # keep the original position_ids
+            prompt_position_ids=original_req.prompt_position_ids,  # keep the original prompt_position_ids
+            response_position_ids=padding_response_position_ids,  # padded response_position_ids
+            loss_mask=original_req.loss_mask,  # keep the original loss_mask
+            prompt_loss_mask=original_req.prompt_loss_mask,  # keep the original prompt_loss_mask
+            response_loss_mask=padding_response_loss_mask,  # all-zero response_loss_mask
+            reward_scores={},  # empty reward_scores
+            max_prompt_len=original_req.max_prompt_len,
+            max_response_len=original_req.max_response_len,
+            max_model_len=original_req.max_model_len,
+            metrics={},  # empty metrics
+            output_token_ids=None,  # no output_token_ids
+            rollout_log_probs=None,  # no rollout_log_probs
+            use_inference_chat_template=original_req.use_inference_chat_template,
+            tokenization_sanity_check_mode=original_req.tokenization_sanity_check_mode,
+            generation_prompt_ids=original_req.generation_prompt_ids,
+            base_conv_wo_gen_prompt_end_pos=original_req.base_conv_wo_gen_prompt_end_pos,
+            base_conv_with_gen_prompt_end_pos=original_req.base_conv_with_gen_prompt_end_pos,
+            processing_class=self.processing_class,  # supply the required processing_class argument
+        )
+
+        logger.info(f"Created padding request for aborted request {original_req.request_id}")
+        return padding_req
+
     def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]:
         assert "raw_prompt" in prompts.non_tensor_batch, (
             "need data.return_raw_chat=True, due to no official way do parse_messages"

Review comment (on state=AsyncRolloutRequestStateEnum.COMPLETED): the state should be changed to abort.
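To see why an all-zero response_loss_mask is enough to neutralize a padding request downstream, here is a minimal sketch of masked loss aggregation. It is not verl's actual agg_loss, just the underlying idea it relies on.

    import torch

    def masked_mean_loss(token_loss: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
        # token_loss, response_mask: (batch, response_len); padding rows have an all-zero mask.
        masked = token_loss * response_mask
        # Padding rows add nothing to either the numerator or the denominator.
        return masked.sum() / response_mask.sum().clamp(min=1)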