[rollout] feat: support over sampling rollout in SGLang Rollout (volcengine#2929)
This PR introduces an **over-sample strategy** for verl's SGLang
multi-turn rollout to address the long-tail problem, where a few slow
requests disproportionately increase the overall rollout time. The core
idea is to launch more requests than needed at the start of the rollout
and then aggressively cancel any requests that haven't finished once a
target number of completions is met (a minimal sketch of the idea
follows the summary below).
- **Improves rollout efficiency** for multi-turn conversations by
reducing total time spent waiting for slow requests.
- **Implements a new request monitoring and cancellation mechanism** to
cut off unnecessary computation.
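As a rough, hedged illustration of the fixed-rate variant (all names here are illustrative, not the PR's actual API), the strategy boils down to:

```python
import asyncio

# Minimal sketch of the over-sample idea (illustrative, not the PR's API):
# launch more requests than needed, keep the first `n_target` completions,
# and cancel the stragglers.
async def over_sample_rollout(make_request, n_target, over_rate=0.25):
    n_launch = int(n_target * (1 + over_rate))
    tasks = [asyncio.create_task(make_request(i)) for i in range(n_launch)]
    results, pending = [], set(tasks)
    while len(results) < n_target and pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        results.extend(t.result() for t in done)
    for t in pending:
        t.cancel()  # drop the long tail
    return results[:n_target]
```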
The wandb results are as follows:
https://wandb.ai/zhaochenyang20/benchmark_over_sample_2/workspace?nw=nwuserzhaochenyang20
-----
Of course, this strategy has its share of issues. For example, many
might question why the over-long requests that are dropped aren't simply
saved and continued in the next round. This is certainly possible—it's a
partial rollout strategy—but it would require verl to have a data
buffer, which is beyond the scope of this PR. Furthermore, saving and
continuing these requests would introduce an off-policy problem.
There is also a valid concern that this rather "brutal" dropping
strategy could unfairly ignore very long requests. I agree this is a
reasonable point, but we currently don't have a lossless solution.
However, the dropping strategy is very flexible and could even evolve
alongside curriculum learning. For instance, in the example above, I
simply dropped the last 20% of requests. **In practice, we can
dynamically adjust this drop rate and even use different dropping
rules. For example, we could record the time (t) at which 80% of the
requests have returned and then drop any requests that haven't returned
by 1.5t.** A hedged sketch of that adaptive rule follows.
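As a rough sketch of the adaptive variant (illustrative only; `tasks` is a collection of in-flight `asyncio.Task`s, and none of these names come from the PR):

```python
import asyncio
import time

# Hypothetical sketch (not the PR's implementation): wait until a quantile
# of requests has returned, record the elapsed time t, then give the
# stragglers until slack * t (e.g. 1.5t) before cancelling them.
async def rollout_with_adaptive_deadline(tasks, quantile=0.8, slack=1.5):
    start = time.monotonic()
    target = max(1, int(len(tasks) * quantile))

    done, pending = set(), set(tasks)
    # Phase 1: wait for the first `quantile` fraction of requests.
    while len(done) < target and pending:
        just_done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        done |= just_done
    t = time.monotonic() - start

    # Phase 2: the stragglers get until slack * t, measured from the start.
    if pending:
        timeout = max(start + slack * t - time.monotonic(), 0.0)
        just_done, pending = await asyncio.wait(pending, timeout=timeout)
        done |= just_done
    for task in pending:
        task.cancel()  # drop whatever is still unfinished
    return done, pending
```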
We've provided an initial, validated idea and have completed its
implementation. We welcome everyone to join the discussion on how to
accelerate multi-turn rollouts with acceptable losses.
The new over-sample strategy was tested with an 8-GPU setup on the
**gsm8k** dataset, yielding the following results:
- **Rollout Time:** Significant reduction in overall rollout time per
step.
- **Training Rewards:**
- The reward metric for training steps shows a positive bias. This is
because we exclude the aborted requests (which are typically more
difficult and have lower rewards) from the reward calculation.
- The reward metric for validation steps remains accurate and aligns
with the baseline. This is because the cancellation logic is not
triggered during validation, ensuring a fair and complete evaluation.
This feature modifies `sglang_rollout.py` and `metric_utils.py`. To use
it, follow the standard setup and then run the training script with the
over-sample parameters.
https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/release_log/over_sample.md
The design is centered on three main functions that orchestrate the
over-sampling logic: `run_with_cancellation`,
`process_request_with_monitoring`, and `monitor_and_cancel`. These
functions rely on shared global state, such as `all_tasks` and
`completion_lock`, to coordinate; a condensed sketch of how the pieces
fit together follows the list below.
- **`run_with_cancellation`:** This is the entry point. It launches all
requests as `process_request_with_monitoring` tasks concurrently with a
single `monitor_and_cancel` task. It uses `asyncio.gather` to wait for
all tasks to complete (or be canceled) and converts any exceptions from
canceled tasks into padding requests before returning the final output.
- **`process_request_with_monitoring`:** This async function handles a
single request. It waits for the request to complete via
`_async_rollout_a_request` and then checks a shared counter,
`completed_count`, under `completion_lock` to avoid races between
coroutines. If the target completion count has not yet been reached, it
returns the real result. If the target has already been met, it returns
padding data instead, effectively "discarding" the late result.
- **`monitor_and_cancel`:** This is a separate async task that polls
`completed_count`. Once the count reaches the `target_completion`
threshold, it immediately cancels all remaining tasks and calls the
engine's `abort_request` method, halting any ongoing GPU computation
for those requests.
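A condensed, hedged sketch of how these pieces might fit together (simplified; `_async_rollout_a_request` follows the description above, while `make_padding_request`, the engine stub, and the `abort_request` signature are illustrative placeholders, not the PR's exact code):

```python
import asyncio

# Shared state, mirroring the globals described above.
completed_count = 0
completion_lock = asyncio.Lock()
all_tasks: list = []

async def _async_rollout_a_request(req, engine):
    """Placeholder for the real multi-turn rollout of a single request."""
    return await engine.generate(req)  # illustrative call

def make_padding_request(req):
    """Placeholder: build a zero-reward padding result for `req`."""
    return {"request": req, "padded": True}

async def process_request_with_monitoring(req, target_completion, engine):
    global completed_count
    # Run one request to completion, then decide whether its result
    # still counts toward the target or arrived too late.
    result = await _async_rollout_a_request(req, engine)
    async with completion_lock:
        if completed_count < target_completion:
            completed_count += 1
            return result  # real result: still under the target
    return make_padding_request(req)  # late result: replaced by padding

async def monitor_and_cancel(target_completion, engine):
    # Poll the shared counter; once the target is met, cancel every
    # remaining task and abort in-flight requests on the engine.
    while True:
        async with completion_lock:
            if completed_count >= target_completion:
                break
        await asyncio.sleep(0.1)
    for task in all_tasks:
        if not task.done():
            task.cancel()
    await engine.abort_request()  # signature illustrative

async def run_with_cancellation(reqs, target_completion, engine):
    global all_tasks, completed_count
    completed_count = 0
    all_tasks = [
        asyncio.create_task(
            process_request_with_monitoring(r, target_completion, engine)
        )
        for r in reqs
    ]
    monitor = asyncio.create_task(monitor_and_cancel(target_completion, engine))
    # Cancelled tasks surface as exceptions; convert them to padding.
    outputs = await asyncio.gather(*all_tasks, return_exceptions=True)
    monitor.cancel()
    return [
        make_padding_request(r) if isinstance(out, BaseException) else out
        for r, out in zip(reqs, outputs)
    ]
```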
Key code changes:
- **`sglang_rollout.py`**:
- Adds the three core asynchronous functions for the over-sample
strategy.
- The `AsyncEngine` class now includes a new `abort_request` method that
calls the synchronous `abort_request` in the `tokenizer_manager`.
- **`metric_utils.py`**:
- The `compute_data_metrics` function is updated to exclude the aborted
requests (identified by padding) from the denominator when calculating
average rewards during training. This prevents the training reward from
being artificially lowered by the zero-reward aborted requests (a
hedged sketch of this adjustment follows).
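As a rough sketch of this adjustment (the function name and mask construction are illustrative, not the actual `compute_data_metrics` code):

```python
import torch

# Illustrative sketch: exclude padded (aborted) samples from the
# training-reward denominator so their zero rewards do not drag the
# average down.
def mean_reward_excluding_aborted(rewards: torch.Tensor,
                                  is_padding: torch.Tensor) -> torch.Tensor:
    valid = ~is_padding
    if valid.sum() == 0:
        # Degenerate case: everything was aborted.
        return torch.tensor(0.0)
    return rewards[valid].mean()

# Example: 80 finished requests at reward 0.5 plus 20 aborted ones.
rewards = torch.cat([torch.full((80,), 0.5), torch.zeros(20)])
is_padding = torch.cat([torch.zeros(80, dtype=torch.bool),
                        torch.ones(20, dtype=torch.bool)])
print(rewards.mean())                                      # 0.4 (naive)
print(mean_reward_excluding_aborted(rewards, is_padding))  # 0.5
```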
This implementation is designed to be a straightforward and effective
solution for the long-tail problem, though some aspects of the
asynchronous design and the impact on training variance require further
investigation.
---------
Co-authored-by: zhaochenyang <[email protected]>
Co-authored-by: PopSoda2002 <[email protected]>
Co-authored-by: ChangyiYang <[email protected]>
Co-authored-by: PrinsYin <[email protected]>
Co-authored-by: WindowsXp-Beta <[email protected]>
Co-authored-by: zhaochenyang20 <[email protected]>