Conversation

@zhaochenyang20 (Collaborator) commented Aug 5, 2025

What does this PR do?

This PR introduces an over-sample strategy for verl's SGLang multi-turn rollout to address the long-tail problem, where a few slow requests disproportionately increase the overall rollout time. The core idea is to over-sample the number of requests at the start of the rollout and then aggressively cancel any requests that haven't finished once a target number of completions is met.

  • Improves rollout efficiency for multi-turn conversations by reducing total time spent waiting for slow requests.
  • Implements a new request monitoring and cancellation mechanism to cut off unnecessary computation.

The wandb results are as follows:

https://wandb.ai/zhaochenyang20/benchmark_over_sample_2/workspace?nw=nwuserzhaochenyang20


Of course, this strategy has its share of issues. For example, many might question why the over-long requests that are dropped aren't simply saved and continued in the next round. This is certainly possible—it's a partial rollout strategy—but it would require verl to have a data buffer, which is beyond the scope of this PR. Furthermore, saving and continuing these requests would introduce an off-policy problem.

There is also a valid concern that this rather "brutal" dropping strategy could unfairly ignore very long requests. I agree this is a very reasonable point, but currently we don't have a lossless solution. However, our dropping strategy is very flexible and could even change with our curriculum learning. For instance, in the example I gave, I simply dropped the slowest 20% of requests. In practice, we can dynamically adjust this drop rate and even use different dropping criteria. For example, we could record the elapsed time t at which 80% of the requests have returned and then drop any request that hasn't returned by 1.5t, as sketched below.
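To illustrate, here is a minimal sketch of such an adaptive deadline, assuming each request is handled by an awaitable coroutine. The function and parameter names (`rollout_with_adaptive_deadline`, `rollout_fn`, `quantile`, `slack`) are hypothetical and not part of this PR.

```python
import asyncio
import time

async def rollout_with_adaptive_deadline(requests, rollout_fn, quantile=0.8, slack=1.5):
    # Wait until `quantile` of the requests have returned, record the elapsed
    # time t, then give stragglers until slack * t in total before dropping them.
    start = time.monotonic()
    pending = {asyncio.ensure_future(rollout_fn(r)) for r in requests}
    target = int(len(pending) * quantile)

    done = set()
    while len(done) < target and pending:
        finished, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        done |= finished
    t = time.monotonic() - start  # elapsed time when `quantile` of requests returned

    finished = set()
    if pending:
        finished, pending = await asyncio.wait(pending, timeout=max(0.0, slack * t - t))
    for task in pending:
        task.cancel()  # dropped; in this PR these would become padding requests
    return [task.result() for task in done | finished]
```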

We've provided an initial, validated idea and have completed its implementation. We welcome everyone to join the discussion on how to accelerate multi-turn rollouts with acceptable losses.

Test

The new over-sample strategy was tested with an 8-GPU setup on the gsm8k dataset, yielding the following results:

  • Rollout Time: Significant reduction in overall rollout time per step.
  • Training Rewards:
    • The reward metric for training steps shows a positive bias. This is because we exclude the aborted requests (which are typically more difficult and have lower rewards) from the reward calculation.
    • The reward metric for validation steps remains accurate and aligns with the baseline. This is because the cancellation logic is not triggered during validation, ensuring a fair and complete evaluation.

API and Usage Example

This feature modifies sglang_rollout.py and metric_utils.py. To use it, follow the standard setup and then run the training script with the over-sample parameters.

https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/release_log/over_sample.md

Design & Code Changes

The design is centered on three main functions that orchestrate the over-sampling logic: run_with_cancellation, process_request_with_monitoring, and monitor_and_cancel. These functions rely on shared global state, such as all_tasks and completion_lock, to manage progress. A simplified, self-contained sketch of how the three pieces fit together appears after the list below.

  • run_with_cancellation: This is the entry point. It launches all requests as process_request_with_monitoring tasks concurrently with a single monitor_and_cancel task. It uses asyncio.gather to wait for all tasks to complete (or be canceled) and converts any exceptions from canceled tasks into padding requests before returning the final output.

  • process_request_with_monitoring: This async function handles a single request. It waits for the request to complete using _async_rollout_a_request and then checks a shared counter, completed_count, under the completion_lock to guard concurrent access. If the target completion count has not been reached, it returns the real result. If the target has been met, it returns padding data instead, effectively "discarding" the late result.

  • monitor_and_cancel: This is a separate async task that polls the completed_count. Once the count reaches the target_completion threshold, it immediately cancels all remaining tasks and sends an abort_requests signal to the SGLang engine, halting any ongoing GPU computation for those requests.
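To make the control flow concrete, here is a minimal, self-contained asyncio sketch of the orchestration described above. It is an illustration rather than the PR's actual code: `rollout_fn` and `PADDING` are hypothetical stand-ins for `_async_rollout_a_request` and the real padding requests, and the engine-side abort call is reduced to a comment.

```python
import asyncio

completed_count = 0
completion_lock = asyncio.Lock()
all_tasks: list[asyncio.Task] = []

PADDING = object()  # stand-in for a real padding request

async def process_request_with_monitoring(req, target_completion, rollout_fn):
    global completed_count
    result = await rollout_fn(req)
    async with completion_lock:
        if completed_count < target_completion:
            completed_count += 1
            return result  # counted toward the target: keep the real result
    return PADDING  # target already met: discard this late result

async def monitor_and_cancel(target_completion, poll_interval=0.05):
    # Poll the shared counter; once the target is hit, cancel everything left.
    while True:
        async with completion_lock:
            if completed_count >= target_completion:
                break
        await asyncio.sleep(poll_interval)
    for task in all_tasks:
        task.cancel()  # the real code also sends abort_requests to the SGLang engine

async def run_with_cancellation(req_list, target_completion, rollout_fn):
    global all_tasks, completed_count
    completed_count = 0
    all_tasks = [
        asyncio.create_task(process_request_with_monitoring(r, target_completion, rollout_fn))
        for r in req_list
    ]
    monitor = asyncio.create_task(monitor_and_cancel(target_completion))
    # Cancelled tasks surface as exceptions here and are turned into padding.
    results = await asyncio.gather(*all_tasks, return_exceptions=True)
    monitor.cancel()
    return [PADDING if isinstance(r, BaseException) else r for r in results]
```

In the actual PR, the rollout coroutine is _async_rollout_a_request, the padding entries are built by _create_padding_request, and cancellation additionally aborts in-flight engine requests so GPU work stops immediately.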

Key code changes:

  • sglang_rollout.py:
    • Adds the three core asynchronous functions for the over-sample strategy.
    • The AsyncEngine class now includes a new abort_request method that calls the synchronous abort_request in the tokenizer_manager.
  • metric_utils.py:
    • The compute_data_metrics function is updated to exclude the aborted requests (identified by padding) from the denominator when calculating average rewards during training. This prevents the training reward from being artificially lowered by the zero-reward aborted requests; a small sketch of this masking idea follows the list.
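For illustration, here is a small, hypothetical sketch of the masking idea; `masked_reward_mean` and its arguments are stand-ins, not the actual compute_data_metrics code.

```python
import torch

def masked_reward_mean(rewards: torch.Tensor, is_padding: torch.Tensor) -> torch.Tensor:
    # Aborted requests carry zero reward and a padding flag; excluding them from
    # the denominator keeps the training-reward metric from being dragged down.
    valid = ~is_padding
    return rewards[valid].sum() / valid.sum().clamp(min=1)

# Example: three real requests and one aborted (padding) request.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
is_padding = torch.tensor([False, False, False, True])
print(masked_reward_mean(rewards, is_padding))  # tensor(0.6667), not 0.5
```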

This implementation is designed to be a straightforward and effective solution for the long-tail problem, though some aspects of the asynchronous design and the impact on training variance require further investigation.

Comment on lines 1124 to 1127
    except Exception as e:
        logger.error(f"Uncaught exception in process_request_with_monitoring: {e}")
        logger.error("This shall not happen, please check the code")
        raise e
Contributor:

I think this exception capture and logging are a bit unnecessary as the exception will be raised anyway.

Comment on lines 1160 to 1163
    if isinstance(result, Exception):
        # if it is an exception (including CancelledError), create padding
        logger.warning(f"Task {i} resulted in exception: {result}")
        output_req_list.append(self._create_padding_request(req_list[i]))
Contributor:

Since all the exceptions are captured, why do we still need this check here?

Collaborator Author:

Just duplicated. Will remove this.

@zhaochenyang20 (Collaborator Author):

Additionally, I added the profiling logic in https://github.com/zhaochenyang20/verl/tree/over_sample_profile; here I can share some results: [profiling screenshot] It works correctly, and the padding time is minor.

Great catch!

@zhaochenyang20 (Collaborator Author):

> Additionally, I added the profiling logic in https://github.com/zhaochenyang20/verl/tree/over_sample_profile; here I can share some results: [profiling screenshot] It works correctly, and the padding time is minor.

Could you also update this doc with your updated branch, taking the over-sample strategy into account?

https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/profile.md

@PopSoda2002 (Contributor):

> Could you also update this doc with your updated branch, taking the over-sample strategy into account?
> https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/profile.md

OK, will do this later.

@zhaochenyang20 (Collaborator Author):

Co-authored by:

  • https://github.com/PrinsYin ([email protected])
  • huapeng (already in the commit)
  • ChangyiYang ([email protected])
  • https://github.com/WindowsXp-Beta

@PopSoda2002 (Contributor):

> Could you also update this doc with your updated branch, taking the over-sample strategy into account?
> https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/profile.md

Done in zhaochenyang20/Awesome-ML-SYS-Tutorial#186

@zhaochenyang20 zhaochenyang20 merged commit ea885f3 into volcengine:main Aug 14, 2025
47 of 52 checks passed
@ChangyiYang (Contributor):

Great work, Chenyang!

@PrinsYin (Contributor) commented Aug 14, 2025:

Thanks, Chenyang!

yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Aug 15, 2025
techkang pushed a commit to techkang/verl that referenced this pull request Aug 15, 2025
ChangyiYang added a commit to SwordFaith/verl that referenced this pull request Aug 16, 2025
whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025