[rollout] feat: support over sampling rollout in SGLang Rollout #2929
Changes from 59 commits
```diff
@@ -118,6 +118,20 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     prompt_length = response_info["prompt_length"]
     response_length = response_info["response_length"]

+    aborted_mask = (response_length == 0).bool()
+    non_aborted_mask = ~aborted_mask
+
+    non_aborted_sequence_score = sequence_score[non_aborted_mask]
+    non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
+
+    score_mean = torch.mean(non_aborted_sequence_score).detach().item()
+    score_max = torch.max(non_aborted_sequence_score).detach().item()
+    score_min = torch.min(non_aborted_sequence_score).detach().item()
+
+    reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
+    reward_max = torch.max(non_aborted_sequence_reward).detach().item()
+    reward_min = torch.min(non_aborted_sequence_reward).detach().item()
+
     valid_adv = torch.masked_select(advantages, response_mask)
     valid_returns = torch.masked_select(returns, response_mask)
```
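The aborted_mask above is the core of the change: with over-sampling, some rollout requests are aborted and come back with response_length == 0, and their placeholder scores would otherwise be averaged into the reported metrics. A minimal, self-contained sketch with made-up values:

```python
import torch

# Hypothetical batch: two of four samples were aborted, so their
# response_length is 0 and their scores are placeholder zeros.
response_length = torch.tensor([128.0, 0.0, 256.0, 0.0])
sequence_score = torch.tensor([0.9, 0.0, 0.7, 0.0])

aborted_mask = (response_length == 0).bool()
non_aborted_mask = ~aborted_mask

print(sequence_score.mean().item())                    # 0.4 (skewed by aborted samples)
print(sequence_score[non_aborted_mask].mean().item())  # 0.8 (non-aborted only)
```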
```diff
@@ -127,15 +141,30 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     return_diff_var = torch.var(valid_returns - valid_values)
     return_var = torch.var(valid_returns)

+    # Aborted samples and non-aborted response length statistics
+    # response_length_non_aborted/*: statistics computed on non-aborted samples only
+    aborted_ratio = torch.mean(aborted_mask.float()).detach().item()
+
+    non_aborted_response_length = response_length[non_aborted_mask]
+    if non_aborted_response_length.numel() > 0:
+        non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
+        non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
+        non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
+        non_aborted_response_length_clip_ratio = (
+            torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
+        )
+    else:
+        raise ValueError("All samples are aborted, this should not happen.")
+
     metrics = {
         # score
-        "critic/score/mean": torch.mean(sequence_score).detach().item(),
-        "critic/score/max": torch.max(sequence_score).detach().item(),
-        "critic/score/min": torch.min(sequence_score).detach().item(),
+        "critic/score/mean": score_mean,
+        "critic/score/max": score_max,
+        "critic/score/min": score_min,
         # reward
-        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
-        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
-        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
+        "critic/rewards/mean": reward_mean,
+        "critic/rewards/max": reward_max,
+        "critic/rewards/min": reward_min,
```
Review thread on these metrics (resolved):

Reviewer: Does it make sense to apply the same non_aborted_mask logic to the response_length metrics as well (mean/max/min response length)? Otherwise the metrics might still include padded responses from aborted requests.

Author: Makes sense.

Author: I changed it accordingly: the response-length statistics are now reported both before and after abort filtering, and I added the drop-rate metric (response/aborted_ratio).
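To make the reviewer's point concrete, here is a small hypothetical batch (values invented, not from a real run) showing how zero-length aborted entries skew the unfiltered statistics:

```python
import torch

max_response_length = 200
# Two of four samples aborted (length 0); values hypothetical.
response_length = torch.tensor([200.0, 0.0, 100.0, 0.0])
non_aborted = response_length[response_length > 0]

print(response_length.mean().item())  # 75.0  -> response_length/mean, dragged down by zeros
print(non_aborted.mean().item())      # 150.0 -> response_length_non_aborted/mean

print(torch.eq(non_aborted, max_response_length).float().mean().item())
# 0.5 -> response_length_non_aborted/clip_ratio
print((response_length == 0).float().mean().item())
# 0.5 -> response/aborted_ratio (the added drop-rate metric)
```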
```diff
         # adv
         "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
         "critic/advantages/max": torch.max(valid_adv).detach().item(),
```
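A note on the else branch in the hunk above: if every sample in a batch were aborted, the filtered tensor would be empty and torch.mean would silently yield nan, so the code raises instead. A tiny sketch of the failure mode being guarded against (hypothetical values):

```python
import torch

response_length = torch.tensor([0.0, 0.0])         # every sample aborted
non_aborted = response_length[response_length > 0]

print(non_aborted.numel())        # 0
print(non_aborted.mean().item())  # nan -- why the code raises ValueError instead
```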
```diff
@@ -163,6 +192,15 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
         "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
         .detach()
         .item(),
+        # response length (non-aborted only)
+        # These statistics exclude aborted samples to avoid skew from zeros
+        "response_length_non_aborted/mean": non_aborted_response_length_mean,
+        "response_length_non_aborted/max": non_aborted_response_length_max,
+        "response_length_non_aborted/min": non_aborted_response_length_min,
+        "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
+        # aborted ratio
+        # Fraction of samples whose response length is zero
+        "response/aborted_ratio": aborted_ratio,
         # prompt length
         "prompt_length/mean": torch.mean(prompt_length).detach().item(),
         "prompt_length/max": torch.max(prompt_length).detach().item(),
```
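For the hypothetical four-sample batch in the sketch above, the new entries added to the returned metrics dict would come out roughly as follows (illustrative values only):

```python
{
    "response_length_non_aborted/mean": 150.0,
    "response_length_non_aborted/max": 200.0,
    "response_length_non_aborted/min": 100.0,
    "response_length_non_aborted/clip_ratio": 0.5,
    "response/aborted_ratio": 0.5,
}
```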
Reviewer (unresolved): this may not be correct.