[trainer/algorithm] Implement DAPO and Polaris style dynamic sampling + add DAPO docs + example #130
Merged · +1,055 −19
Commits (54, all by erictang000):

- 69fd7e0 fix bug
- f59abaa remove fsdp from fsdp2 hf save model architecture
- 51119e5 merge
- 5659f9b x
- fc9e355 thanks gemini
- 6810779 remove extra ray.shutdown
- c4bde2a deepspeed + fsdp add configs to checkpoint folder
- ac018fd Merge branch 'main' of https://github.com/erictang000/SkyRL into conf…
- 0e8facc pull to parent function for shared logic
- 9a865f5 x
- 7202d21 docs
- 4445e42 x
- bec693e x
- 9b7c7d2 address gemini comments
- 119d9cd x
- f32ffa9 Merge branch 'main' of https://github.com/erictang000/SkyRL
- 35db88e Merge branch 'config_checkpointing' of https://github.com/erictang000…
- 3cce025 Merge branch 'main' of https://github.com/erictang000/SkyRL
- f5267b1 Merge branch 'main' of https://github.com/erictang000/SkyRL
- e11db0a x
- ad7b045 unit tests passing - need to test both e2e
- 5a909f5 x
- 615133d Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
- 43edec4 x
- 19f5816 fixes
- 8419051 fixes
- 7bf7e54 x
- 2b2e326 Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
- 40f26e8 fix weight manager logic
- 4822d52 x
- 5232569 thanks gemini
- 81a3819 x
- 3706f07 x
- 0c40fed x
- 13574c9 x
- b9f03d7 x
- bcd53eb Apply suggestions from code review
- c063aee address comments
- f0890d2 Merge branch 'dynamic_sampling' of https://github.com/erictang000/Sky…
- 46ddda9 fix tests
- 8118f55 add soft overlong punishment
- 3b2d007 x
- a2ac205 thanks gemini
- 8fa7e42 x
- 0ba0d52 change to overriding trainer
- 1c24673 x
- e8f7be4 x
- 87d6add x
- 5c91bcc x
- eac1c77 x
- c777948 add more docs for custom trainer
- 401fd72 add ref to dapo example
- e7c5616 x
- 59bb246 thanks gemini
New file (+113 lines): DAPO documentation page
DAPO
====

The `DAPO <https://arxiv.org/abs/2503.14476>`_ (Decoupled Clip and Dynamic Sampling Policy Optimization) algorithm consists of the following components on top of a GRPO baseline:

- **Clip-Higher**: Promotes the diversity of the system and avoids entropy collapse;
- **Dynamic Sampling**: Improves training efficiency and stability;
- **Token-Level Policy Gradient Loss**: Critical in long-CoT RL scenarios;
- **Overlong Reward Shaping**: Reduces reward noise and stabilizes training.

In this guide, we walk through how to enable each of these components in SkyRL. We provide a simple example script for training DAPO on GSM8K in :code_link:`examples/algorithms/dapo/`.

Clip-Higher
~~~~~~~~~~~
To use clip-higher, you can simply configure ``trainer.algorithm.eps_clip_high`` separately from ``trainer.algorithm.eps_clip_low``.

.. code-block:: yaml

   trainer:
     algorithm:
       eps_clip_low: 0.2
       eps_clip_high: 0.28
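For intuition, Clip-Higher decouples the two sides of the PPO clipping range so the upper bound can be raised without loosening the lower bound, leaving more room for low-probability tokens to increase in likelihood. The sketch below is illustrative only (it is not SkyRL-Train's loss implementation); it just shows where ``eps_clip_low`` and ``eps_clip_high`` enter a clipped surrogate objective.

.. code-block:: python

   import torch

   def clipped_surrogate(log_probs, old_log_probs, advantages, eps_clip_low=0.2, eps_clip_high=0.28):
       # importance ratio between the current policy and the rollout policy
       ratio = torch.exp(log_probs - old_log_probs)
       # decoupled clip range: [1 - eps_clip_low, 1 + eps_clip_high]
       clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip_low, 1.0 + eps_clip_high)
       # pessimistic (min) PPO objective per token, negated to form a loss
       return -torch.minimum(ratio * advantages, clipped_ratio * advantages)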
Dynamic Sampling
~~~~~~~~~~~~~~~~
In DAPO-style dynamic sampling, we sample rollouts until we have a full batch with non-zero advantages (meaning that the rewards of the n rollouts for a given prompt have a non-zero standard deviation).

To configure DAPO-style dynamic sampling, you can set ``trainer.algorithm.dynamic_sampling.type`` to ``filter`` and set ``trainer.algorithm.dynamic_sampling.max_sample_batches`` to the maximum number of batches to sample.
If ``max_sample_batches > 0`` and this limit is exceeded, SkyRL-Train will raise an error. If ``max_sample_batches <= 0``, SkyRL-Train will sample until a full batch with non-zero advantages is accumulated.

.. code-block:: yaml

   trainer:
     algorithm:
       dynamic_sampling:
         type: filter
         max_sample_batches: 30
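Conceptually, the ``filter`` strategy groups rollouts by prompt, drops groups whose rewards are all identical (zero advantage for every response in the group), and keeps sampling new batches until enough surviving groups fill the training batch. A minimal sketch of the filtering step is shown below; the function and variable names are illustrative, not SkyRL-Train APIs.

.. code-block:: python

   from collections import defaultdict

   def keep_uids_with_nonzero_std(uids, rewards):
       # group sequence-level rewards by prompt uid
       rewards_by_uid = defaultdict(list)
       for uid, reward in zip(uids, rewards):
           rewards_by_uid[uid].append(reward)
       # keep only prompts whose n rollouts do not all share the same reward
       return {uid for uid, rs in rewards_by_uid.items() if len(set(rs)) > 1}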
Token-Level Policy Gradient Loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DAPO uses token-level policy gradient loss, which can be enabled by setting ``trainer.algorithm.loss_reduction`` to ``token_mean``. This is the default setting in SkyRL-Train.

.. code-block:: yaml

   trainer:
     algorithm:
       loss_reduction: "token_mean"
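The distinction matters because a per-sequence mean weights every response equally regardless of its length, whereas ``token_mean`` weights every generated token equally, which the DAPO paper argues is important for long-CoT responses. A minimal, illustrative sketch of the two reductions (not SkyRL-Train's actual loss code; ``loss_mask`` is assumed to be a 0/1 tensor over response tokens):

.. code-block:: python

   def token_mean(per_token_loss, loss_mask):
       # average over all unmasked tokens in the batch:
       # longer responses contribute more terms to the objective
       return (per_token_loss * loss_mask).sum() / loss_mask.sum()

   def sequence_mean(per_token_loss, loss_mask):
       # average within each sequence first, then across sequences:
       # every response is weighted equally regardless of length
       per_sequence = (per_token_loss * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1)
       return per_sequence.mean()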
Overlong Reward Shaping
~~~~~~~~~~~~~~~~~~~~~~~~
The DAPO paper proposes two methods for overlong reward shaping:

- **Overlong Filtering**: Sets the loss mask to all zeros for responses that exceed the max response length.
- **Soft Overlong Punishment**: Penalizes responses that exceed the max response length within a punishment interval. Within this interval, the longer the response, the greater the punishment it receives. This penalty is added to the original reward.

Overlong Filtering
------------------

To enable overlong filtering, which sets the loss mask to all zeros for responses that do not finish with a stop token (i.e., responses that are too long), you can set ``generator.apply_overlong_filtering`` to ``true``.

.. code-block:: yaml

   generator:
     apply_overlong_filtering: true
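Under the hood this amounts to zeroing the loss mask for any rollout that was cut off by the length limit rather than finishing with a stop token. The sketch below is only illustrative; the field names (e.g. ``stop_reasons``) are assumptions, not SkyRL-Train's exact data layout.

.. code-block:: python

   def overlong_filter(loss_masks, stop_reasons):
       # zero the loss mask for responses truncated by the length limit,
       # so they contribute nothing to the policy gradient
       return [
           mask if stop_reason == "stop" else [0] * len(mask)
           for mask, stop_reason in zip(loss_masks, stop_reasons)
       ]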
.. _dapo-custom-trainer:

Soft Overlong Punishment
------------------------

To enable soft overlong punishment, you can create a custom trainer class and override ``RayPPOTrainer``'s ``postprocess_generator_output`` method to additionally apply soft overlong punishment to the rewards.
We provide an example of this in :code_link:`examples/algorithms/dapo/main_dapo.py`, and show an overview of the implementation below:

.. code-block:: python
   :caption: ``skyrl_train/examples/algorithms/dapo/main_dapo.py``

   class DAPOTrainer(RayPPOTrainer):
       @torch.no_grad()
       def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
           # apply soft overlong punishment
           overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
           overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
           ...
           # use base class impl for metrics and per-token reward conversion
           return super().postprocess_generator_output(generator_output, uids)


   class DAPOExp(BasePPOExp):
       def get_trainer(self, *args, **kwargs):
           return DAPOTrainer(*args, **kwargs)


   @ray.remote(num_cpus=1)
   def skyrl_entrypoint(cfg: DictConfig):
       exp = DAPOExp(cfg)
       exp.run()
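For example, with ``overlong_buffer.len=512`` and ``penalty_factor=1.0``, a response that runs 256 tokens past the start of the overlong buffer is penalized by ``256 / 512 * 1.0 = 0.5``, which is subtracted from its original reward, while a response that exceeds the maximum context length entirely has its reward set to 0.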
To add the overlong buffer length and penalty factor parameters to the config, you can add the following lines to the ``run_dapo_gsm8k.sh`` script:

.. code-block:: bash
   :caption: ``skyrl_train/examples/algorithms/dapo/run_dapo_gsm8k.sh``

   +trainer.algorithm.overlong_buffer.len=512 \
   +trainer.algorithm.overlong_buffer.penalty_factor=1.0 \

Launching a DAPO Training Run
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

An example script with all of the above components enabled for basic GSM8K training can be found at :code_link:`examples/algorithms/dapo/run_dapo_gsm8k.sh`.

.. code-block:: bash

   export WANDB_API_KEY=your_wandb_api_key
   bash examples/algorithms/dapo/run_dapo_gsm8k.sh
4 files renamed without changes.
New file (+98 lines): skyrl_train/examples/algorithms/dapo/main_dapo.py
""" | ||
uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo | ||
""" | ||
|
||
import ray | ||
import hydra | ||
import torch | ||
from typing import List | ||
from omegaconf import DictConfig | ||
from skyrl_train.trainer import RayPPOTrainer | ||
from skyrl_train.utils import initialize_ray | ||
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg | ||
|
||
from skyrl_train.generators.base import GeneratorOutput | ||
|
||
|
||
class DAPOTrainer(RayPPOTrainer): | ||
""" | ||
Custom trainer for DAPO. | ||
|
||
Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards. | ||
""" | ||
|
||
@torch.no_grad() | ||
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: | ||
""" | ||
Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards. | ||
|
||
Args: | ||
generator_output: GeneratorOutput | ||
uids: List[str] | ||
|
||
Returns: | ||
GeneratorOutput | ||
""" | ||
overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len | ||
overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor | ||
# modify rewards here | ||
prompt_token_ids = generator_output["prompt_token_ids"] | ||
response_ids = generator_output["response_ids"] | ||
rewards = generator_output["rewards"] | ||
|
||
assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here" | ||
|
||
# get the prompt length | ||
prompt_lengths = [len(prompt) for prompt in prompt_token_ids] | ||
|
||
# get the response length | ||
response_lengths = [len(response) for response in response_ids] | ||
|
||
# get the max context length | ||
max_context_length = ( | ||
self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length | ||
) | ||
|
||
# apply soft overlong punishment | ||
for i, (prompt_length, response_length) in enumerate(zip(prompt_lengths, response_lengths)): | ||
# max_exceed_length is the beginning of the overlong buffer | ||
max_exceed_length = max_context_length - overlong_buffer_len - prompt_length | ||
# if the response is within the overlong buffer, apply the penalty | ||
if response_length > max_exceed_length and response_length <= max_context_length - prompt_length: | ||
exceed_length = response_length - max_exceed_length | ||
penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor | ||
|
||
rewards[i] -= penalty | ||
# if the response is outside the overlong buffer, set the reward to 0 | ||
elif response_length > max_context_length - prompt_length: | ||
# if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses | ||
rewards[i] = 0.0 | ||
|
||
generator_output["rewards"] = rewards | ||
|
||
# use base class impl for metrics and per-token reward conversion | ||
return super().postprocess_generator_output(generator_output, uids) | ||
|
||
|
||
class DAPOExp(BasePPOExp): | ||
def get_trainer(self, *args, **kwargs): | ||
return DAPOTrainer(*args, **kwargs) | ||
|
||
|
||
@ray.remote(num_cpus=1) | ||
def skyrl_entrypoint(cfg: DictConfig): | ||
exp = DAPOExp(cfg) | ||
exp.run() | ||
|
||
|
||
@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) | ||
def main(cfg: DictConfig) -> None: | ||
# validate the arguments | ||
validate_cfg(cfg) | ||
|
||
initialize_ray(cfg) | ||
ray.get(skyrl_entrypoint.remote(cfg)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |