
[trainer/algorithm] Implement DAPO and Polaris style dynamic sampling + add DAPO docs + example #130


Merged

54 commits merged on Aug 7, 2025. Changes shown from 44 commits.

Commits
69fd7e0
fix bug
erictang000 Jul 25, 2025
f59abaa
remove fsdp from fsdp2 hf save model architecture
erictang000 Jul 28, 2025
51119e5
merge
erictang000 Jul 28, 2025
5659f9b
x
erictang000 Jul 28, 2025
fc9e355
thanks gemini
erictang000 Jul 28, 2025
6810779
remove extra ray.shutdown
erictang000 Jul 28, 2025
c4bde2a
deepspeed + fsdp add configs to checkpoint folder
erictang000 Jul 29, 2025
ac018fd
Merge branch 'main' of https://github.com/erictang000/SkyRL into conf…
erictang000 Jul 29, 2025
0e8facc
pull to parent function for shared logic
erictang000 Jul 29, 2025
9a865f5
x
erictang000 Jul 29, 2025
7202d21
docs
erictang000 Jul 29, 2025
4445e42
x
erictang000 Jul 29, 2025
bec693e
x
erictang000 Jul 29, 2025
9b7c7d2
address gemini comments
erictang000 Jul 29, 2025
119d9cd
x
erictang000 Jul 29, 2025
f32ffa9
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 29, 2025
35db88e
Merge branch 'config_checkpointing' of https://github.com/erictang000…
erictang000 Jul 29, 2025
3cce025
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 29, 2025
f5267b1
Merge branch 'main' of https://github.com/erictang000/SkyRL
erictang000 Jul 31, 2025
e11db0a
x
erictang000 Aug 1, 2025
ad7b045
unit tests passing - need to test both e2e
erictang000 Aug 2, 2025
5a909f5
x
erictang000 Aug 2, 2025
615133d
Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
erictang000 Aug 4, 2025
43edec4
x
erictang000 Aug 4, 2025
19f5816
fixes
erictang000 Aug 4, 2025
8419051
fixes
erictang000 Aug 4, 2025
7bf7e54
x
erictang000 Aug 4, 2025
2b2e326
Merge branch 'main' of https://github.com/erictang000/SkyRL into dyna…
erictang000 Aug 4, 2025
40f26e8
fix weight manager logic
erictang000 Aug 4, 2025
4822d52
x
erictang000 Aug 4, 2025
5232569
thanks gemini
erictang000 Aug 4, 2025
81a3819
x
erictang000 Aug 5, 2025
3706f07
x
erictang000 Aug 5, 2025
0c40fed
x
erictang000 Aug 5, 2025
13574c9
x
erictang000 Aug 5, 2025
b9f03d7
x
erictang000 Aug 5, 2025
bcd53eb
Apply suggestions from code review
erictang000 Aug 6, 2025
c063aee
address comments
erictang000 Aug 6, 2025
f0890d2
Merge branch 'dynamic_sampling' of https://github.com/erictang000/Sky…
erictang000 Aug 6, 2025
46ddda9
fix tests
erictang000 Aug 6, 2025
8118f55
add soft overlong punishment
erictang000 Aug 7, 2025
3b2d007
x
erictang000 Aug 7, 2025
a2ac205
thanks gemini
erictang000 Aug 7, 2025
8fa7e42
x
erictang000 Aug 7, 2025
0ba0d52
change to overriding trainer
erictang000 Aug 7, 2025
1c24673
x
erictang000 Aug 7, 2025
e8f7be4
x
erictang000 Aug 7, 2025
87d6add
x
erictang000 Aug 7, 2025
5c91bcc
x
erictang000 Aug 7, 2025
eac1c77
x
erictang000 Aug 7, 2025
c777948
add more docs for custom trainer
erictang000 Aug 7, 2025
401fd72
add ref to dapo example
erictang000 Aug 7, 2025
e7c5616
x
erictang000 Aug 7, 2025
59bb246
thanks gemini
erictang000 Aug 7, 2025
64 changes: 64 additions & 0 deletions skyrl-train/docs/algorithms/dapo.rst
@@ -0,0 +1,64 @@
DAPO
====

The `DAPO <https://arxiv.org/abs/2503.14476>`_ (Decoupled Clip and Dynamic Sampling Policy Optimization) algorithm consists of the following components on top of a GRPO baseline:

- **Clip-Higher**: Promotes the diversity of the system and avoids entropy collapse;
- **Dynamic Sampling**: Improves training efficiency and stability;
- **Token-Level Policy Gradient Loss**: Critical in long-CoT RL scenarios;
- **Overlong Reward Shaping**: Reduces reward noise and stabilizes training.

In this guide, we walk through how to enable each of these components in SkyRL. We provide a simple example script for training DAPO on GSM8K in :code_link:`examples/algorithms/dapo/`.

Clip-Higher
~~~~~~~~~~~
To use clip-higher, you can simply configure ``trainer.algorithm.eps_clip_high`` separately from ``trainer.algorithm.eps_clip_low``.

.. code-block:: yaml

  trainer:
    algorithm:
      eps_clip_low: 0.2
      eps_clip_high: 0.28

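For intuition, here is a minimal sketch of how a decoupled clip range enters the PPO surrogate loss. The function and argument names are illustrative, not SkyRL-Train internals:

.. code-block:: python

  import torch

  def clip_higher_surrogate(log_probs, old_log_probs, advantages, eps_clip_low=0.2, eps_clip_high=0.28):
      # Importance ratio between the current and old policy.
      ratio = torch.exp(log_probs - old_log_probs)
      # Decoupled clipping: the upper bound (1 + eps_clip_high) is looser than the
      # lower bound (1 - eps_clip_low), leaving more room for low-probability tokens
      # to increase and helping avoid entropy collapse.
      clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip_low, 1.0 + eps_clip_high)
      return -torch.min(ratio * advantages, clipped_ratio * advantages)
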
Dynamic Sampling
~~~~~~~~~~~~~~~~
In DAPO-style dynamic sampling, we keep sampling rollouts until we have a full batch of prompts with non-zero advantages (i.e., prompts whose n rollouts have a non-zero standard deviation of rewards).

To configure DAPO-style dynamic sampling, set ``trainer.algorithm.dynamic_sampling.type`` to ``filter`` and set ``trainer.algorithm.dynamic_sampling.max_sample_batches`` to the maximum number of batches to sample.
If ``max_sample_batches > 0`` and this limit is exceeded, SkyRL-Train raises an error; if ``max_sample_batches <= 0``, SkyRL-Train samples until a full batch with non-zero advantages has been accumulated.

.. code-block:: yaml

  trainer:
    algorithm:
      dynamic_sampling:
        type: filter
        max_sample_batches: 30

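Conceptually, the ``filter`` strategy behaves like the loop below. This is an illustrative sketch only; ``sample_batch`` and the rollout objects are assumed helpers, not SkyRL-Train APIs:

.. code-block:: python

  def sample_until_full_batch(sample_batch, train_batch_size, max_sample_batches=30):
      """Keep sampling until we have train_batch_size prompts with non-zero advantages."""
      kept, num_sampled = [], 0
      while len(kept) < train_batch_size:
          if max_sample_batches > 0 and num_sampled >= max_sample_batches:
              raise RuntimeError("Exceeded max_sample_batches before filling a batch")
          num_sampled += 1
          for prompt, rollouts in sample_batch():  # rollouts: the n samples for this prompt
              rewards = [r.reward for r in rollouts]
              # Keep only prompts whose rewards have non-zero std, i.e. non-zero GRPO advantages.
              if max(rewards) != min(rewards):
                  kept.append((prompt, rollouts))
      return kept[:train_batch_size]
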
Token-Level Policy Gradient Loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DAPO uses token-level policy gradient loss, which can be enabled by setting ``trainer.algorithm.loss_reduction`` to ``token_mean``. This is the default setting in SkyRL-Train.

.. code-block:: yaml

  trainer:
    algorithm:
      loss_reduction: "token_mean"

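The difference between averaging over tokens and averaging per sequence can be seen in a small sketch (illustrative code, not SkyRL-Train internals; the sequence-level branch is shown only for contrast):

.. code-block:: python

  def reduce_loss(per_token_loss, response_mask, loss_reduction="token_mean"):
      if loss_reduction == "token_mean":
          # Average over every valid token in the batch, so long responses
          # contribute proportionally more tokens to the gradient (as in DAPO).
          return (per_token_loss * response_mask).sum() / response_mask.sum()
      else:  # sequence-level mean, shown only for contrast
          # Average within each sequence first, then across sequences, so every
          # response is weighted equally regardless of its length.
          per_seq = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1)
          return per_seq.mean()
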
Overlong Reward Shaping
~~~~~~~~~~~~~~~~~~~~~~~~
The DAPO paper proposes two methods for overlong reward shaping:

- **Overlong Filtering**: Sets the loss mask to all zeros for responses that exceed the max response length.
- **Soft Overlong Punishment**: Penalizes responses whose length falls within a punishment interval just below the max response length. Within this interval, the longer the response, the greater the penalty; the penalty is added to the original reward.

To enable overlong filtering, which sets the loss mask to all zeros for responses that do not finish with a stop token (i.e., responses that were truncated for being too long), set ``generator.apply_overlong_filtering`` to ``true``.

To enable soft overlong punishment, you can register a custom advantage estimator, which we show an example of in :code_link:`examples/algorithms/dapo/main_dapo.py`.

.. code-block:: yaml

  generator:
    apply_overlong_filtering: true

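Conceptually, overlong filtering amounts to zeroing the loss mask of truncated responses. A minimal sketch (illustrative only, not the generator's actual implementation):

.. code-block:: python

  def overlong_filtering(loss_masks, stop_reasons):
      # Zero out the loss mask for responses that did not end with a stop token
      # (e.g. they hit the maximum generation length), so truncated responses
      # contribute no policy-gradient signal.
      return [
          mask if reason == "stop" else [0] * len(mask)
          for mask, reason in zip(loss_masks, stop_reasons)
      ]
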
An example script with all of the above components enabled can be found at :code_link:`examples/algorithms/dapo/run_dapo_gsm8k.sh`.
12 changes: 12 additions & 0 deletions skyrl-train/docs/configuration/config.rst
@@ -307,6 +307,13 @@ Algorithm Configuration
value_clip: 0.2
normalize_reward: true

# dynamic sampling parameters
dynamic_sampling:
  type: null # filter (DAPO), replace (POLARIS/WebSailor), or null
  max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
  min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)


- ``algorithm.advantage_estimator``: Advantage estimator to use. We currently implement ``grpo`` and ``gae``, and custom advantage estimators can be registered with the ``AdvantageEstimatorRegistry``.
- ``algorithm.use_kl_estimator_k3``: Whether to use the k3 estimator for KL divergence calculation. The k3 estimator is the non-negative KL approximation in `this blog post <http://joschu.net/blog/kl-approx.html>`_. Besides being non-negative, it is also unbiased and has lower variance.
- ``algorithm.use_abs_kl``: Whether to use the absolute KL divergence for KL divergence calculation.
@@ -324,6 +331,11 @@ Algorithm Configuration
- ``algorithm.clip_ratio_c``: Clip ratio for dual clip PPO loss.
- ``algorithm.value_clip``: Clip value for value loss.
- ``algorithm.normalize_reward``: Whether to normalize critic model output (i.e., values). When ``true``, the critic model learns the mean and standard deviation of the values during training and normalizes the values during the forward pass.
- ``algorithm.dynamic_sampling``: Dynamic sampling configuration.
- ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO <https://dapo-sia.github.io/>`_), ``replace`` (`POLARIS <https://hkunlp.github.io/blog/2025/Polaris/>`_ / `WebSailor <https://arxiv.org/abs/2507.02592>`_), or ``null`` for no dynamic sampling.
- ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever.
- ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for the ``replace`` strategy; see the sketch below.

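For intuition, the ``replace`` strategy can be sketched as follows. This is an illustrative sketch under the assumption that each prompt's rollouts carry scalar rewards; it is not the SkyRL-Train implementation:

.. code-block:: python

  import random

  def replace_bad_groups(groups, min_replace_ratio=0.3):
      # ``groups``: one entry per prompt, each holding that prompt's n rollouts (with .reward).
      good, bad = [], []
      for rollouts in groups:
          rewards = [r.reward for r in rollouts]
          # A group is "good" if its rewards have non-zero std (non-zero GRPO advantage).
          (good if max(rewards) != min(rewards) else bad).append(rollouts)
      # Only replace when enough good groups exist to copy from.
      if len(good) < min_replace_ratio * len(groups):
          return groups
      # Keep the batch size fixed by substituting random good groups for the bad ones.
      return good + [random.choice(good) for _ in bad]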

Policy Loss Formulation
~~~~~~~~~~~~~~~~~~~~~~~
15 changes: 8 additions & 7 deletions skyrl-train/docs/index.rst
@@ -43,7 +43,14 @@ SkyRL is a full-stack RL library designed for modularity and extensibility.

recipes/skyrl-sql
recipes/searchr1


.. toctree::
   :maxdepth: 2
   :caption: Algorithms

   algorithms/dapo
   algorithms/custom_algorithms

.. toctree::
:maxdepth: 2
:caption: Configuration
@@ -70,12 +77,6 @@ SkyRL is a full-stack RL library designed for modularity and extensibility.
api/registry
api/tools

.. toctree::
:maxdepth: 2
:caption: Algorithms

algorithms/custom_algorithms

.. toctree::
:maxdepth: 2
:caption: Troubleshooting
84 changes: 84 additions & 0 deletions skyrl-train/examples/algorithms/dapo/main_dapo.py
@@ -0,0 +1,84 @@
"""
uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo
"""

import ray
import hydra
import torch
import numpy as np
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry, compute_grpo_outcome_advantage


# Custom advantage estimator to implement soft overlong punishment for DAPO
def compute_grpo_with_soft_overlong_punishment(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, **kwargs
):
    """
    Applies soft overlong punishment to the token-level rewards and then computes GRPO advantages.

    Args:
        token_level_rewards: (batch_size, seqlen) tensor of token-level rewards
        response_mask: (batch_size, seqlen) tensor of response mask
        index: (batch_size) tensor of prompt indices

    Returns:
        advantages: (batch_size, seqlen) tensor of advantages
        returns: (batch_size, seqlen) tensor of returns
    """
    # this assumes response-level rewards
    scores = token_level_rewards.sum(dim=-1)

    # Overlong punishment params - hardcoded for this script for now
    # TODO (erictang000): make these configurable (in general for all custom registries)
    max_resp_length = 1024  # this is generator.sampling_params.max_generate_length in the `run_dapo_gsm8k.sh` script
    overlong_buffer_len = 512  # overlong buffer is last 512 tokens of the response as an example
    overlong_penalty_factor = 1.0  # reward penalty increases linearly from 0 to 1.0 as the response length enters the overlong buffer

    # add soft overlong punishment
    lengths = response_mask.sum(dim=-1)
    buffer_start_idx = max_resp_length - overlong_buffer_len
Review comment from @SumanthRH (Member), Aug 7, 2025, on the penalty computation above:

    Wait, why would we do this here? This is pretty hacky and not what we'd want to show as an example of reward customization.

    The ideal place is probably in the Generator, but we don't have any postprocessing hooks there.

    For now, we can quickly make this custom postprocessing in the trainer by subclassing the base method

        def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:

    and adding the reward penalty there:

        class DAPOTrainer(RayPPOTrainer):
            @torch.no_grad()
            def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
                # modify rewards here
                ....
                # use base class impl for metrics and per-token reward conversion
                return super().postprocess_generator_output(....)

Reply from @erictang000 (Collaborator, Author): yes, agree this is cleaner; updated to do this!

    # apply penalty
    penalty_mask = lengths > buffer_start_idx
    penalty = (lengths[penalty_mask] - buffer_start_idx) / overlong_buffer_len * overlong_penalty_factor
    scores[penalty_mask] -= penalty
    # for responses that have length >= max_resp_length, overlong filtering is already applied in the config
    # by setting apply_overlong_filtering=true

    # reconstruct response-level rewards in format expected in compute_grpo_outcome_advantage
    new_token_level_rewards = torch.zeros_like(token_level_rewards)
    new_token_level_rewards[:, -1] = scores

    # compute GRPO advantages
    advantages, returns = compute_grpo_outcome_advantage(
        new_token_level_rewards, response_mask, index, epsilon=1e-6, norm_adv_by_std_in_grpo=True
    )

    return advantages, returns


# Register our custom advantage estimator
AdvantageEstimatorRegistry.register("grpo_with_soft_overlong_punishment", compute_grpo_with_soft_overlong_punishment)


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    exp = BasePPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # validate the arguments
    validate_cfg(cfg)

    initialize_ray(cfg)
    ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
    main()
83 changes: 83 additions & 0 deletions skyrl-train/examples/algorithms/dapo/run_dapo_gsm8k.sh
@@ -0,0 +1,83 @@
set -x

# Colocated DAPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K.

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/algorithms/dapo/run_dapo_gsm8k.sh

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout

# main DAPO parameters
EPS_CLIP_LOW=0.2
EPS_CLIP_HIGH=0.28
DYNAMIC_SAMPLING_TYPE=filter
DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES=30
LOSS_REDUCTION="token_mean"
# applies overlong filtering (but not soft overlong punishment)
APPLY_OVERLONG_FILTERING=true
# apply soft overlong punishment using custom advantage estimator registered in main_dapo.py
ADV_ESTIMATOR="grpo_with_soft_overlong_punishment"

# other DAPO parameters
USE_KL_LOSS=false
TEMPERATURE=1.0
TOP_P=1.0
EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_RESPONSE_LENGTH=1024

uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator=$ADV_ESTIMATOR \
trainer.algorithm.policy_loss_type="dual_clip" \
trainer.algorithm.eps_clip_low=$EPS_CLIP_LOW \
trainer.algorithm.eps_clip_high=$EPS_CLIP_HIGH \
trainer.algorithm.dynamic_sampling.type=$DYNAMIC_SAMPLING_TYPE \
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
generator.sampling_params.temperature=$TEMPERATURE \
generator.sampling_params.top_p=$TOP_P \
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_GPUS \
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=1024 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=64 \
trainer.micro_train_batch_size_per_gpu=64 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.policy.optimizer_config.max_grad_norm=1.0 \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k" \
trainer.run_name="gsm8k_dapo" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
$@
1 change: 1 addition & 0 deletions skyrl-train/scripts/full_context/trainer_full_ctx.py
@@ -84,6 +84,7 @@ def train(self):
self.tracker.log(self.all_metrics, step=self.global_step)
self.all_metrics = {}
self.tracker.log({"timing/" + k: v for k, v in self.all_timings.items()}, step=self.global_step)
self.all_timings = {}
self.global_step += 1

logger.info(f"Step {step + 1} completed. Status: {status}")
4 changes: 4 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -102,6 +102,10 @@ trainer:
# value loss parameters
value_clip: 0.2
normalize_reward: true
dynamic_sampling:
  type: null # filter, replace, or null
  max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
  min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)

gradient_checkpointing: true
gradient_checkpointing_use_reentrant: false