[rollout] feat: support over sampling rollout in SGLang Rollout #2929
Conversation
```python
except Exception as e:
    logger.error(f"Uncaught exception in process_request_with_monitoring: {e}")
    logger.error("This shall not happen, please check the code")
    raise e
```
I think this exception capture and logging are a bit unnecessary as the exception will be raised anyway.
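For reference, a minimal sketch of the point above (not code from this PR): if the handler only logs and re-raises, removing the `try/except` is behaviorally equivalent for callers, and when logging is still wanted, a bare `raise` preserves the original traceback.

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

def risky():
    raise ValueError("boom")

def with_logging():
    try:
        risky()
    except Exception as e:
        logger.error(f"Uncaught exception: {e}")
        raise  # bare raise re-raises the active exception with its traceback

def without_logging():
    # Equivalent for callers: the exception propagates either way.
    risky()

for fn in (with_logging, without_logging):
    try:
        fn()
    except ValueError as e:
        print(f"{fn.__name__}: caught {e}")
```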
```python
if isinstance(result, Exception):
    # if it is an exception (including CancelledError), create padding
    logger.warning(f"Task {i} resulted in exception: {result}")
    output_req_list.append(self._create_padding_request(req_list[i]))
```
Since all the exceptions are already captured, why do we still need this check here?
Just duplicated. Will remove this.
Great catch!
Could you also update these docs from your branch to account for over-sampling?
ok will do this later
Co-authored by: https://github.com/PrinsYin, huapeng (already in the commit), ChangyiYang [email protected]
Great work, chenyang!
Thanks chenyang!
### What does this PR do?

This PR introduces an **over-sample strategy** for verl's SGLang multi-turn rollout to address the long-tail problem, where a few slow requests disproportionately increase the overall rollout time. The core idea is to over-sample the number of requests at the start of the rollout and then aggressively cancel any requests that haven't finished once a target number of completions is met.

- **Improves rollout efficiency** for multi-turn conversations by reducing the total time spent waiting for slow requests.
- **Implements a new request monitoring and cancellation mechanism** to cut off unnecessary computation.

The wandb results are as follows:
https://wandb.ai/zhaochenyang20/benchmark_over_sample_2/workspace?nw=nwuserzhaochenyang20

-----

Of course, this strategy has its share of issues. For example, many might question why the over-long requests that are dropped aren't simply saved and continued in the next round. This is certainly possible (it is a partial rollout strategy), but it would require verl to have a data buffer, which is beyond the scope of this PR. Furthermore, saving and continuing these requests would introduce an off-policy problem.

There is also a valid concern that this rather "brutal" dropping strategy could unfairly ignore very long requests. I agree this is a very reasonable point, but currently we don't have a lossless solution. However, our dropping strategy is very flexible and could even change with our curriculum learning. For instance, in the example I gave, I just directly dropped the last 20% of requests. **In practice, we can dynamically adjust this drop rate and even use different dropping methods. For example, we could record the return time (t) for the first 80% of requests and then drop any requests that haven't returned after 1.5t.**

We've provided an initial, validated idea and have completed its implementation.
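The dynamic drop rule mentioned above (record the elapsed time t at which the first 80% of requests have returned, then cancel anything still running at 1.5t) could be sketched roughly as follows. This is a hypothetical illustration, not code from the PR; `rollout_one`, the `"<padding>"` sentinel, and the constants are stand-ins.

```python
import asyncio
import time

async def rollout_one(i: int) -> str:
    # Stand-in for the real per-request rollout; latency grows with i.
    await asyncio.sleep(0.01 * (i + 1))
    return f"result-{i}"

async def oversample_with_dynamic_deadline(n: int, target_frac: float = 0.8,
                                           slack: float = 1.5) -> list:
    tasks = [asyncio.create_task(rollout_one(i)) for i in range(n)]
    start = time.monotonic()
    target = int(n * target_frac)

    # Phase 1: wait until the target fraction of requests has returned,
    # recording the elapsed time t at that point.
    done: set = set()
    pending = set(tasks)
    while len(done) < target and pending:
        just_done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED)
        done |= just_done
    t = time.monotonic() - start

    # Phase 2: give the stragglers until slack * t total elapsed time,
    # then cancel whatever is still running.
    if pending:
        just_done, pending = await asyncio.wait(
            pending, timeout=(slack - 1.0) * t)
        done |= just_done
    for task in pending:
        task.cancel()

    # Cancelled slots become padding so the batch shape is preserved.
    return ["<padding>" if task in pending else task.result()
            for task in tasks]
```

Under these assumptions the drop rate is no longer fixed at 20%: fast steps drop nothing, while steps with a long tail drop exactly the requests slower than 1.5t.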
We welcome everyone to join the discussion on how to accelerate multi-turn rollouts with acceptable losses.

### Test

The new over-sample strategy was tested with an 8-GPU setup on the **gsm8k** dataset, yielding the following results:

- **Rollout Time:** Significant reduction in overall rollout time per step.
- **Training Rewards:**
  - The reward metric for training steps shows a positive bias. This is because we exclude the aborted requests (which are typically more difficult and have lower rewards) from the reward calculation.
  - The reward metric for validation steps remains accurate and aligns with the baseline. This is because the cancellation logic is not triggered during validation, ensuring a fair and complete evaluation.

### API and Usage Example

This feature modifies `sglang_rollout.py` and `metric_utils.py`. To use it, follow the standard setup and then run the training script with the over-sample parameters.

https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/release_log/over_sample.md

### Design & Code Changes

The design is centered on three main functions that orchestrate the over-sampling logic: `run_with_cancellation`, `process_request_with_monitoring`, and `monitor_and_cancel`. These functions rely on global variables, such as `all_tasks` and `completion_lock`, to manage state.

- **`run_with_cancellation`:** This is the entry point. It launches all requests as `process_request_with_monitoring` tasks concurrently with a single `monitor_and_cancel` task. It uses `asyncio.gather` to wait for all tasks to complete (or be canceled) and converts any exceptions from canceled tasks into padding requests before returning the final output.
- **`process_request_with_monitoring`:** This async function handles a single request. It waits for the request to complete using `_async_rollout_a_request` and then checks a shared counter, `completed_count`, using a `completion_lock` for thread safety. If the target completion count has not been reached, it returns the real result. If the target has been met, it returns padding data instead, effectively "discarding" the late result.
- **`monitor_and_cancel`:** This is a separate async task that polls `completed_count`. Once the count reaches the `target_completion` threshold, it immediately cancels all remaining tasks and sends an `abort_requests` signal to the SGLang engine, halting any ongoing GPU computation for those requests.

Key code changes:

- **`sglang_rollout.py`**:
  - Adds the three core asynchronous functions for the over-sample strategy.
  - The `AsyncEngine` class now includes a new `abort_request` method that calls the synchronous `abort_request` in the `tokenizer_manager`.
- **`metric_utils.py`**:
  - The `compute_data_metrics` function is updated to exclude the aborted requests (identified by padding) from the denominator when calculating average rewards during training. This prevents the training reward from being artificially lowered by the zero-reward aborted requests.

This implementation is designed to be a straightforward and effective solution for the long-tail problem, though some aspects of the asynchronous design and the impact on training variance require further investigation.

---------

Co-authored-by: zhaochenyang <[email protected]>
Co-authored-by: PopSoda2002 <[email protected]>
Co-authored-by: ChangyiYang <[email protected]>
Co-authored-by: PrinsYin <[email protected]>
Co-authored-by: WindowsXp-Beta <[email protected]>
Co-authored-by: zhaochenyang20 <[email protected]>
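The overall pattern (concurrent request tasks plus one monitor task that cancels the stragglers once `target_completion` is reached) can be sketched as follows. The function names and the `completed_count` counter come from the PR description; everything else is illustrative. This sketch passes the shared state explicitly rather than via globals, and the real implementation in `sglang_rollout.py` additionally aborts the in-flight requests on the SGLang engine.

```python
import asyncio

async def rollout_request(i: int, state: dict, lock: asyncio.Lock,
                          target: int) -> str:
    """Stand-in for process_request_with_monitoring."""
    await asyncio.sleep(0.01 * (i + 1))  # fake per-request latency
    async with lock:
        if state["completed_count"] >= target:
            return "<padding>"  # late result: discard as padding
        state["completed_count"] += 1
        return f"result-{i}"

async def monitor_and_cancel(tasks, state, lock, target):
    """Poll the shared counter; cancel stragglers once the target is met."""
    while True:
        async with lock:
            if state["completed_count"] >= target:
                break
        await asyncio.sleep(0.005)
    for t in tasks:
        if not t.done():
            t.cancel()  # the real code also aborts the request on the engine

async def run_with_cancellation(n: int, target: int) -> list:
    state = {"completed_count": 0}
    lock = asyncio.Lock()
    tasks = [asyncio.create_task(rollout_request(i, state, lock, target))
             for i in range(n)]
    monitor = asyncio.create_task(monitor_and_cancel(tasks, state, lock, target))
    results = await asyncio.gather(*tasks, return_exceptions=True)
    monitor.cancel()
    # Cancelled tasks surface as CancelledError (a BaseException in
    # modern Python); convert any exception result into padding.
    return ["<padding>" if isinstance(r, BaseException) else r
            for r in results]
```

Note the `BaseException` check: since Python 3.8, `asyncio.CancelledError` no longer inherits from `Exception`, so an `isinstance(r, Exception)` test would miss cancelled tasks.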