Merged
54 commits
69fd7e0
fix bug
erictang000 Jul 25, 2025
f59abaa
remove fsdp from fsdp2 hf save model architecture
erictang000 Jul 28, 2025
51119e5
merge
erictang000 Jul 28, 2025
5659f9b
x
erictang000 Jul 28, 2025
fc9e355
thanks gemini
erictang000 Jul 28, 2025
6810779
remove extra ray.shutdown
erictang000 Jul 28, 2025
c4bde2a
deepspeed + fsdp add configs to checkpoint folder
erictang000 Jul 29, 2025
ac018fd
Merge branch 'main' of https://github.com/erictang000/SkyRL into conf…
erictang000 Jul 29, 2025
0e8facc
pull to parent function for shared logic
erictang000 Jul 29, 2025
9a865f5
x
erictang000 Jul 29, 2025
7202d21
docs
erictang000 Jul 29, 2025
4445e42
x
erictang000 Jul 29, 2025
bec693e
x
erictang000 Jul 29, 2025
9b7c7d2
address gemini comments
erictang000 Jul 29, 2025
119d9cd
x
erictang000 Jul 29, 2025
f32ffa9
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 29, 2025
35db88e
Merge branch 'config_checkpointing' of https://github.com/erictang000…
erictang000 Jul 29, 2025
3cce025
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 29, 2025
f5267b1
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 31, 2025
e11db0a
x
erictang000 Aug 1, 2025
ad7b045
unit tests passing - need to test both e2e
erictang000 Aug 2, 2025
5a909f5
x
erictang000 Aug 2, 2025
615133d
Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
erictang000 Aug 4, 2025
43edec4
x
erictang000 Aug 4, 2025
19f5816
fixes
erictang000 Aug 4, 2025
8419051
fixes
erictang000 Aug 4, 2025
7bf7e54
x
erictang000 Aug 4, 2025
2b2e326
Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
erictang000 Aug 4, 2025
40f26e8
fix weight manager logic
erictang000 Aug 4, 2025
4822d52
x
erictang000 Aug 4, 2025
5232569
thanks gemini
erictang000 Aug 4, 2025
81a3819
x
erictang000 Aug 5, 2025
3706f07
x
erictang000 Aug 5, 2025
0c40fed
x
erictang000 Aug 5, 2025
13574c9
x
erictang000 Aug 5, 2025
b9f03d7
x
erictang000 Aug 5, 2025
bcd53eb
Apply suggestions from code review
erictang000 Aug 6, 2025
c063aee
address comments
erictang000 Aug 6, 2025
f0890d2
Merge branch 'dynamic_sampling' of https://github.com/erictang000/Sky…
erictang000 Aug 6, 2025
46ddda9
fix tests
erictang000 Aug 6, 2025
8118f55
add soft overlong punishment
erictang000 Aug 7, 2025
3b2d007
x
erictang000 Aug 7, 2025
a2ac205
thanks gemini
erictang000 Aug 7, 2025
8fa7e42
x
erictang000 Aug 7, 2025
0ba0d52
change to overriding trainer
erictang000 Aug 7, 2025
1c24673
x
erictang000 Aug 7, 2025
e8f7be4
x
erictang000 Aug 7, 2025
87d6add
x
erictang000 Aug 7, 2025
5c91bcc
x
erictang000 Aug 7, 2025
eac1c77
x
erictang000 Aug 7, 2025
c777948
add more docs for custom trainer
erictang000 Aug 7, 2025
401fd72
add ref to dapo example
erictang000 Aug 7, 2025
e7c5616
x
erictang000 Aug 7, 2025
59bb246
thanks gemini
erictang000 Aug 7, 2025
41 changes: 39 additions & 2 deletions skyrl-train/docs/algorithms/custom_algorithms.rst
@@ -5,6 +5,10 @@ SkyRL-Train provides a registry system for easily implementing custom algorithms
The API for the registry system can be found in the :doc:`registry API <../api/registry>`.
Example scripts using the registry can be found at :code_link:`examples/algorithm/`.

Additionally, for more control, you can subclass the ``BasePPOExp`` class from :code_link:`skyrl_train/entrypoints/main_base.py` and override the ``BasePPOExp.get_trainer`` method to return a custom trainer class.
This gives you full control over the training loop and lets you implement custom reward functions and output postprocessing.
We provide an example of this for applying custom reward penalties in our :ref:`DAPO example <dapo-custom-trainer>`.

Registering a Custom Advantage Estimator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -53,8 +57,8 @@ Similarly, you can register custom policy loss functions:
# return loss and clip ratio
return loss, 0.0

Ray Distribution
~~~~~~~~~~~~~~~~
Registry Ray Distribution
~~~~~~~~~~~~~~~~~~~~~~~~~~

The registry system handles Ray actor synchronization when Ray is initialized. Functions registered on one process will be available to all Ray actors:

@@ -81,5 +85,38 @@
exp = BasePPOExp(cfg)
exp.run()

Creating a Custom Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~

To create a custom trainer for full control of your training loop, you can subclass the ``BasePPOExp`` class from :code_link:`skyrl_train/entrypoints/main_base.py` and override the ``BasePPOExp.get_trainer`` method to return a custom trainer class.
We show the outline of creating a custom trainer below, and you can find a full running example in our :ref:`DAPO example <dapo-custom-trainer>`.

.. code-block:: python

    class CustomTrainer(RayPPOTrainer):
        @torch.no_grad()
        def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
            # apply custom reward penalties
            ...
            # use base class impl for metrics and per-token reward conversion
            return super().postprocess_generator_output(generator_output, uids)

    class CustomExp(BasePPOExp):
        def get_trainer(self, *args, **kwargs):
            return CustomTrainer(*args, **kwargs)

    @ray.remote(num_cpus=1)
    def skyrl_entrypoint(cfg: DictConfig):
        exp = CustomExp(cfg)
        exp.run()

    @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
    def main(cfg: DictConfig) -> None:
        # validate the arguments
        validate_cfg(cfg)

        initialize_ray(cfg)
        ray.get(skyrl_entrypoint.remote(cfg))

    if __name__ == "__main__":
        main()
113 changes: 113 additions & 0 deletions skyrl-train/docs/algorithms/dapo.rst
@@ -0,0 +1,113 @@
DAPO
====

The `DAPO <https://arxiv.org/abs/2503.14476>`_ (Decoupled Clip and Dynamic Sampling Policy Optimization) algorithm consists of the following components on top of a GRPO baseline:

- **Clip-Higher**: Promotes the diversity of the system and avoids entropy collapse;
- **Dynamic Sampling**: Improves training efficiency and stability;
- **Token-Level Policy Gradient Loss**: Critical in long-CoT RL scenarios;
- **Overlong Reward Shaping**: Reduces reward noise and stabilizes training.

In this guide, we walk through how to enable each of these components in SkyRL. We provide a simple example script for training DAPO on GSM8K in :code_link:`examples/algorithms/dapo/`.

Clip-Higher
~~~~~~~~~~~
To use clip-higher, you can simply configure ``trainer.algorithm.eps_clip_high`` separately from ``trainer.algorithm.eps_clip_low``.
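
For intuition, the two thresholds enter the PPO-style clipped surrogate objective roughly as follows (sketched here for illustration; see the DAPO paper for the exact objective):

.. math::

    \mathcal{L}^{\text{clip}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\, 1 - \varepsilon_{\text{low}},\, 1 + \varepsilon_{\text{high}}\right)\hat{A}_t\right)\right]

Raising ``eps_clip_high`` relative to ``eps_clip_low`` gives low-probability tokens more room to increase per update, which helps preserve entropy.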

.. code-block:: yaml

    trainer:
      algorithm:
        eps_clip_low: 0.2
        eps_clip_high: 0.28

Dynamic Sampling
~~~~~~~~~~~~~~~~
In DAPO-style dynamic sampling, rollouts are sampled until we accumulate a full batch with non-zero advantages (i.e., the rewards of the n rollouts for a given prompt have a non-zero standard deviation).
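
For illustration, the ``filter`` strategy can be thought of as the following minimal sketch (a hypothetical helper, not the actual SkyRL-Train implementation): rollouts are grouped by prompt ``uid``, groups whose rewards are all identical are dropped, and sampling continues until a full batch of kept rollouts is accumulated.

.. code-block:: python

    import numpy as np

    def keep_nonzero_advantage_rollouts(uids: list[str], rewards: list[float]) -> list[int]:
        """Return indices of rollouts whose prompt group has a non-zero reward std."""
        groups: dict[str, list[int]] = {}
        for i, uid in enumerate(uids):
            groups.setdefault(uid, []).append(i)
        keep: list[int] = []
        for idxs in groups.values():
            if np.std([rewards[i] for i in idxs]) > 0:  # non-zero advantage group
                keep.extend(idxs)
        return keep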

To configure DAPO-style dynamic sampling, set ``trainer.algorithm.dynamic_sampling.type`` to ``filter`` and set ``trainer.algorithm.dynamic_sampling.max_sample_batches`` to the maximum number of batches to sample.
If ``max_sample_batches > 0`` and this limit is exceeded, SkyRL-Train will raise an error. If ``max_sample_batches <= 0``, SkyRL-Train will sample until a full batch with non-zero advantages is accumulated.

.. code-block:: yaml

    trainer:
      algorithm:
        dynamic_sampling:
          type: filter
          max_sample_batches: 30

Token-Level Policy Gradient Loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DAPO uses token-level policy gradient loss, which can be enabled by setting ``trainer.algorithm.loss_reduction`` to ``token_mean``. This is the default setting in SkyRL-Train.
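
For intuition, the difference between token-level and sequence-level loss reduction can be sketched as below (illustrative only, not the SkyRL-Train loss code):

.. code-block:: python

    import torch

    def token_mean(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # every valid token in the batch contributes equally (used by DAPO)
        return (per_token_loss * mask).sum() / mask.sum()

    def sequence_mean(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # each sequence contributes equally, so tokens in long responses are down-weighted
        per_seq = (per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
        return per_seq.mean()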

.. code-block:: yaml

    trainer:
      algorithm:
        loss_reduction: "token_mean"

Overlong Reward Shaping
~~~~~~~~~~~~~~~~~~~~~~~~
The DAPO paper proposes two methods for overlong reward shaping:

- **Overlong Filtering**: Sets loss mask to be all zeros for responses that exceed the max response length.
- **Soft Overlong Punishment**: Penalizes responses that exceed the max response length within a punishment interval. Within this interval, the longer the response, the greater the punishment it receives. This penalty is added to the original reward.

Overlong Filtering
------------------

To enable overlong filtering, which sets the loss mask to all zeros for responses that do not finish with a stop token (i.e., responses that are too long), set ``generator.apply_overlong_filtering`` to ``true``.
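
Conceptually, this zeroes out the loss mask of any response that was stopped by the length limit rather than a stop token. A minimal sketch (assuming the generator reports a ``"length"`` stop reason for truncated responses):

.. code-block:: python

    def apply_overlong_filtering(loss_masks: list[list[int]], stop_reasons: list[str]) -> list[list[int]]:
        # drop truncated responses from the loss by zeroing their masks
        return [
            mask if stop_reason != "length" else [0] * len(mask)
            for mask, stop_reason in zip(loss_masks, stop_reasons)
        ]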

.. code-block:: yaml

    generator:
      apply_overlong_filtering: true

.. _dapo-custom-trainer:

Soft Overlong Punishment
------------------------

To enable soft overlong punishment, you can create a custom trainer class and override the ``RayPPOTrainer.postprocess_generator_output`` method to additionally apply a soft overlong penalty to the rewards.
We provide an example of this in :code_link:`examples/algorithms/dapo/main_dapo.py`; the penalty it applies and an outline of the implementation are shown below:
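
With maximum response length :math:`L_{\max}` (the max context length minus the prompt length), buffer length :math:`L_{\text{buffer}}`, and penalty factor :math:`\alpha`, the penalty added to the reward in the example is:

.. math::

    R_{\text{penalty}}(y) =
    \begin{cases}
    0, & |y| \le L_{\max} - L_{\text{buffer}} \\
    -\dfrac{|y| - (L_{\max} - L_{\text{buffer}})}{L_{\text{buffer}}}\,\alpha, & L_{\max} - L_{\text{buffer}} < |y| \le L_{\max}
    \end{cases}

Responses longer than :math:`L_{\max}` have their reward set to zero in the example (and are dropped from the loss entirely when overlong filtering is enabled).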

.. code-block:: python
    :caption: ``skyrl_train/examples/algorithms/dapo/main_dapo.py``

    class DAPOTrainer(RayPPOTrainer):
        @torch.no_grad()
        def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
            # apply soft overlong punishment
            overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
            overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
            ...
            # use base class impl for metrics and per-token reward conversion
            return super().postprocess_generator_output(generator_output, uids)

    class DAPOExp(BasePPOExp):
        def get_trainer(self, *args, **kwargs):
            return DAPOTrainer(*args, **kwargs)

    @ray.remote(num_cpus=1)
    def skyrl_entrypoint(cfg: DictConfig):
        exp = DAPOExp(cfg)
        exp.run()

To add the overlong buffer length and penalty factor parameters to the config, you can add the following lines to the ``run_dapo_gsm8k.sh`` script:

.. code-block:: bash
    :caption: ``skyrl_train/examples/algorithms/dapo/run_dapo_gsm8k.sh``

    +trainer.algorithm.overlong_buffer.len=512 \
    +trainer.algorithm.overlong_buffer.penalty_factor=1.0 \

Launching a DAPO Training Run
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

An example script with all of the above components enabled for basic GSM8K training can be found at :code_link:`examples/algorithms/dapo/run_dapo_gsm8k.sh`.

.. code-block:: bash

    export WANDB_API_KEY=your_wandb_api_key
    bash examples/algorithms/dapo/run_dapo_gsm8k.sh
12 changes: 12 additions & 0 deletions skyrl-train/docs/configuration/config.rst
@@ -307,6 +307,13 @@ Algorithm Configuration
value_clip: 0.2
normalize_reward: true

# dynamic sampling parameters
dynamic_sampling:
type: null # filter (DAPO), replace (POLARIS/WebSailor), or null
max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)


- ``algorithm.advantage_estimator``: Advantage estimator to use. We currently implement ``grpo`` and ``gae``, and custom advantage estimators can be registered with the ``AdvantageEstimatorRegistry``.
- ``algorithm.use_kl_estimator_k3``: Whether to use the k3 estimator for KL divergence calculation. The k3 estimator is the non-negative KL approximation in `this blog post <http://joschu.net/blog/kl-approx.html>`_. Besides being non-negative, it is also unbiased and has lower variance.
- ``algorithm.use_abs_kl``: Whether to use the absolute KL divergence for KL divergence calculation.
@@ -324,6 +331,11 @@ Algorithm Configuration
- ``algorithm.clip_ratio_c``: Clip ratio for dual clip PPO loss.
- ``algorithm.value_clip``: Clip value for value loss.
- ``algorithm.normalize_reward``: Whether to normalize critic model output (i.e., values). When ``true``, the critic model learns the mean and standard deviation of the values during training and normalizes the values during forward pass.
- ``algorithm.dynamic_sampling``: Dynamic sampling configuration.
- ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO <https://dapo-sia.github.io/>`_), ``replace`` (`POLARIS <https://hkunlp.github.io/blog/2025/Polaris/>`_ / `WebSailor <https://arxiv.org/abs/2507.02592>`_), or ``null`` for no dynamic sampling.
- ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever.
- ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples (used only by the ``replace`` strategy).


Policy Loss Formulation
~~~~~~~~~~~~~~~~~~~~~~~
15 changes: 8 additions & 7 deletions skyrl-train/docs/index.rst
@@ -43,7 +43,14 @@ SkyRL is a full-stack RL library designed for modularity and extensibility.

recipes/skyrl-sql
recipes/searchr1


.. toctree::
   :maxdepth: 2
   :caption: Algorithms

   algorithms/dapo
   algorithms/custom_algorithms

.. toctree::
:maxdepth: 2
:caption: Configuration
@@ -70,12 +77,6 @@ SkyRL is a full-stack RL library designed for modularity and extensibility.
api/registry
api/tools

.. toctree::
:maxdepth: 2
:caption: Algorithms

algorithms/custom_algorithms

.. toctree::
:maxdepth: 2
:caption: Troubleshooting
98 changes: 98 additions & 0 deletions skyrl-train/examples/algorithms/dapo/main_dapo.py
@@ -0,0 +1,98 @@
"""
uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo
"""

import ray
import hydra
import torch
from typing import List
from omegaconf import DictConfig
from skyrl_train.trainer import RayPPOTrainer
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg

from skyrl_train.generators.base import GeneratorOutput


class DAPOTrainer(RayPPOTrainer):
    """
    Custom trainer for DAPO.

    Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards.
    """

    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        """
        Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

        Args:
            generator_output: GeneratorOutput
            uids: List[str]

        Returns:
            GeneratorOutput
        """
        overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
        overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
        # modify rewards here
        prompt_token_ids = generator_output["prompt_token_ids"]
        response_ids = generator_output["response_ids"]
        rewards = generator_output["rewards"]

        assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here"

        # get the prompt length
        prompt_lengths = [len(prompt) for prompt in prompt_token_ids]

        # get the response length
        response_lengths = [len(response) for response in response_ids]

        # get the max context length
        max_context_length = (
            self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length
        )

        # apply soft overlong punishment
        for i, (prompt_length, response_length) in enumerate(zip(prompt_lengths, response_lengths)):
            # max_exceed_length is the beginning of the overlong buffer
            max_exceed_length = max_context_length - overlong_buffer_len - prompt_length
            # if the response is within the overlong buffer, apply the penalty
            if response_length > max_exceed_length and response_length <= max_context_length - prompt_length:
                exceed_length = response_length - max_exceed_length
                penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor

                rewards[i] -= penalty
            # if the response is outside the overlong buffer, set the reward to 0
            elif response_length > max_context_length - prompt_length:
                # if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses
                rewards[i] = 0.0

        generator_output["rewards"] = rewards

        # use base class impl for metrics and per-token reward conversion
        return super().postprocess_generator_output(generator_output, uids)


class DAPOExp(BasePPOExp):
    def get_trainer(self, *args, **kwargs):
        return DAPOTrainer(*args, **kwargs)


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    exp = DAPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # validate the arguments
    validate_cfg(cfg)

    initialize_ray(cfg)
    ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
    main()