[trainer/algorithm] Implement DAPO and Polaris style dynamic sampling + add DAPO docs + example #130
Changes from 44 commits

@@ -0,0 +1,64 @@
DAPO
====

The `DAPO <https://arxiv.org/abs/2503.14476>`_ (Decoupled Clip and Dynamic Sampling Policy Optimization) algorithm consists of the following components on top of a GRPO baseline:

- **Clip-Higher**: Promotes the diversity of the system and avoids entropy collapse;
- **Dynamic Sampling**: Improves training efficiency and stability;
- **Token-Level Policy Gradient Loss**: Critical in long-CoT RL scenarios;
- **Overlong Reward Shaping**: Reduces reward noise and stabilizes training.

In this guide, we walk through how to enable each of these components in SkyRL. We provide a simple example script for training DAPO on GSM8K in :code_link:`examples/algorithms/dapo/`.

Clip-Higher
~~~~~~~~~~~
To use Clip-Higher, simply configure ``trainer.algorithm.eps_clip_high`` separately from ``trainer.algorithm.eps_clip_low``:

.. code-block:: yaml

    trainer:
      algorithm:
        eps_clip_low: 0.2
        eps_clip_high: 0.28
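
Raising ``eps_clip_high`` above ``eps_clip_low`` widens only the upper side of the PPO clipping range, giving low-probability tokens more room to increase in probability. Below is a minimal sketch of how decoupled clipping enters the surrogate loss (illustrative only; the function and argument names here are not SkyRL-Train's actual loss implementation):

.. code-block:: python

    import torch

    def clip_higher_surrogate(log_probs, old_log_probs, advantages, eps_clip_low=0.2, eps_clip_high=0.28):
        # Importance-sampling ratio between the current and old policy.
        ratio = torch.exp(log_probs - old_log_probs)
        # Decoupled clipping: the upper bound (1 + eps_clip_high) is raised above
        # the lower bound (1 - eps_clip_low), which helps avoid entropy collapse.
        clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip_low, 1.0 + eps_clip_high)
        # Standard pessimistic PPO surrogate, negated for minimization.
        return -torch.minimum(ratio * advantages, clipped_ratio * advantages)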

Dynamic Sampling
~~~~~~~~~~~~~~~~
In DAPO-style dynamic sampling, we sample rollouts until we have a full batch with non-zero advantages, meaning that the rewards of the ``n`` rollouts for a given prompt have non-zero standard deviation.

To configure DAPO-style dynamic sampling, set ``trainer.algorithm.dynamic_sampling.type`` to ``filter`` and set ``trainer.algorithm.dynamic_sampling.max_sample_batches`` to the maximum number of batches to sample.
If ``max_sample_batches > 0`` and this limit is exceeded, SkyRL-Train will raise an error. If ``max_sample_batches <= 0``, SkyRL-Train will sample until a full batch with non-zero advantages is accumulated.

.. code-block:: yaml

    trainer:
      algorithm:
        dynamic_sampling:
          type: filter
          max_sample_batches: 30
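
Conceptually, filter-style dynamic sampling drops prompt groups whose ``n`` rollouts all receive the same reward (zero standard deviation, hence zero GRPO advantages and no gradient signal), and keeps sampling additional batches until enough surviving groups fill the training batch. A minimal sketch of the filtering step, with illustrative names only (not the SkyRL-Train internals):

.. code-block:: python

    import numpy as np

    def filter_zero_advantage_groups(rewards: np.ndarray, uids: np.ndarray) -> list:
        """Return rollout indices whose prompt group has non-zero reward std."""
        keep = []
        for uid in np.unique(uids):
            group = np.where(uids == uid)[0]
            # If every rollout for this prompt gets the same reward, its GRPO
            # advantages are all zero, so the group is filtered out.
            if np.std(rewards[group]) > 0:
                keep.extend(group.tolist())
        return keep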

Token-Level Policy Gradient Loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DAPO uses token-level policy gradient loss, which can be enabled by setting ``trainer.algorithm.loss_reduction`` to ``token_mean``. This is the default setting in SkyRL-Train.

.. code-block:: yaml

    trainer:
      algorithm:
        loss_reduction: "token_mean"
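
For context, ``token_mean`` averages the per-token loss over every response token in the batch, so longer responses contribute more tokens to the gradient, whereas a per-sequence mean would weight every response equally regardless of its length. A minimal sketch of the two reductions (illustrative helper names, not SkyRL-Train's implementation):

.. code-block:: python

    import torch

    def token_mean_loss(per_token_loss: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
        # Average over all valid response tokens in the batch (DAPO's choice):
        # long responses contribute proportionally more tokens.
        return (per_token_loss * response_mask).sum() / response_mask.sum()

    def sequence_mean_loss(per_token_loss: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
        # Average within each sequence first, then across sequences: every
        # response carries equal weight regardless of its length.
        per_sequence = (per_token_loss * response_mask).sum(dim=-1) / response_mask.sum(dim=-1)
        return per_sequence.mean()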

Overlong Reward Shaping
~~~~~~~~~~~~~~~~~~~~~~~~
The DAPO paper proposes two methods for overlong reward shaping:

- **Overlong Filtering**: Sets the loss mask to all zeros for responses that exceed the maximum response length.
- **Soft Overlong Punishment**: Penalizes responses whose length falls within a punishment interval just below the maximum response length. Within this interval, the longer the response, the greater the punishment. This penalty is added to the original reward.

To enable overlong filtering, which sets the loss mask to all zeros for responses that do not finish with a stop token (i.e., responses that are too long), set ``generator.apply_overlong_filtering`` to ``true``:

.. code-block:: yaml

    generator:
      apply_overlong_filtering: true

To enable soft overlong punishment, you can register a custom advantage estimator, an example of which is shown in :code_link:`examples/algorithms/dapo/main_dapo.py`.
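
For reference, the soft overlong punishment described in the DAPO paper takes roughly the following form, where :math:`|y|` is the response length, :math:`L_{\max}` is the maximum response length, and :math:`L_{\text{buffer}}` is the length of the punishment interval (a sketch of the paper's shaping term, not a SkyRL-Train config):

.. math::

   R_{\text{length}}(y) =
   \begin{cases}
   0, & |y| \le L_{\max} - L_{\text{buffer}} \\
   \frac{(L_{\max} - L_{\text{buffer}}) - |y|}{L_{\text{buffer}}}, & L_{\max} - L_{\text{buffer}} < |y| \le L_{\max}
   \end{cases}

The custom advantage estimator in :code_link:`examples/algorithms/dapo/main_dapo.py` applies this interval-based penalty (scaled by a penalty factor) to the response-level rewards before computing GRPO advantages.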

An example script with all of the above components enabled can be found at :code_link:`examples/algorithms/dapo/run_dapo_gsm8k.sh`.

@@ -0,0 +1,84 @@
""" | ||||
uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo | ||||
""" | ||||
|
||||
import ray | ||||
import hydra | ||||
import torch | ||||
import numpy as np | ||||
from omegaconf import DictConfig | ||||
from skyrl_train.utils import initialize_ray | ||||
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg | ||||
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry, compute_grpo_outcome_advantage | ||||
|
||||
|
||||
# Custom advantage estimator to implement soft overlong punishment for DAPO | ||||
def compute_grpo_with_soft_overlong_punishment( | ||||
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, **kwargs | ||||
): | ||||
""" | ||||
Applies soft overlong punishment to the token-level rewards and then computes GRPO advantages. | ||||
|
||||
Args: | ||||
token_level_rewards: (batch_size, seqlen) tensor of token-level rewards | ||||
response_mask: (batch_size, seqlen) tensor of response mask | ||||
index: (batch_size) tensor of prompt indices | ||||
|
||||
Returns: | ||||
advantages: (batch_size, seqlen) tensor of advantages | ||||
returns: (batch_size, seqlen) tensor of returns | ||||
""" | ||||
# this assumes response-level rewards | ||||
scores = token_level_rewards.sum(dim=-1) | ||||
|
||||
# Overlong punishment params - hardcoded for this script for now | ||||
# TODO (erictang000): make these configurable (in general for all custom registries) | ||||
max_resp_length = 1024 # this is generator.sampling_params.max_generate_length in the `run_dapo_gsm8k.sh` script | ||||
overlong_buffer_len = 512 # overlong buffer is last 512 tokens of the response as an example | ||||
overlong_penalty_factor = ( | ||||
1.0 # reward penalty increases linearly from 0 to 1.0 as the response length enters the overlong buffer | ||||
) | ||||
|
||||
# add soft overlong punishment | ||||
lengths = response_mask.sum(dim=-1) | ||||
buffer_start_idx = max_resp_length - overlong_buffer_len | ||||

Review comment: Wait, why would we do this here? This is pretty hacky and not what we'd want to show for an example of reward customization. The ideal place is probably in the Generator, but we don't have any postprocessing hooks there. For now, we can quickly make this custom postprocessing in the trainer by subclassing the trainer (SkyRL/skyrl-train/skyrl_train/trainer.py, Line 672 in bd9d6a9):

class DAPOTrainer(RayPPOTrainer):
    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        # modify rewards here
        ....
        # use base class impl for metrics and per-token reward conversion
        return super().postprocess_generator_output(....)

Reply: yes, agree this is cleaner, updated to do this!
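
For context, a rough sketch of what the suggested subclass could look like. The postprocess_generator_output signature is the one shown in the comment above; the import path is inferred from the referenced file, and the generator_output field names used here ("response_ids", "rewards") are assumptions for illustration, not the actual SkyRL-Train API:

# Sketch only. Assumptions (not verified against SkyRL-Train):
# - RayPPOTrainer is importable from skyrl_train.trainer (inferred from the file path above)
# - generator_output is dict-like with "rewards" (per-response scalar rewards)
#   and "response_ids" (lists of generated token ids)
from typing import List

import torch

from skyrl_train.trainer import RayPPOTrainer


class DAPOTrainer(RayPPOTrainer):
    @torch.no_grad()
    def postprocess_generator_output(self, generator_output, uids: List[str]):
        max_resp_length = 1024  # mirrors generator.sampling_params.max_generate_length
        overlong_buffer_len = 512
        overlong_penalty_factor = 1.0
        buffer_start = max_resp_length - overlong_buffer_len

        # Apply the soft overlong penalty to response-level rewards before the
        # base class converts them to per-token rewards and logs metrics.
        for i, response_ids in enumerate(generator_output["response_ids"]):
            length = len(response_ids)
            if length > buffer_start:
                generator_output["rewards"][i] -= (
                    (length - buffer_start) / overlong_buffer_len * overlong_penalty_factor
                )

        return super().postprocess_generator_output(generator_output, uids)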

    # apply penalty
    penalty_mask = lengths > buffer_start_idx
    penalty = (lengths[penalty_mask] - buffer_start_idx) / overlong_buffer_len * overlong_penalty_factor
    scores[penalty_mask] -= penalty
    # for responses that have length >= max_resp_length, overlong filtering is already applied in the config
    # by setting apply_overlong_filtering=true

    # reconstruct response-level rewards in the format expected by compute_grpo_outcome_advantage
    new_token_level_rewards = torch.zeros_like(token_level_rewards)
    new_token_level_rewards[:, -1] = scores

    # compute GRPO advantages
    advantages, returns = compute_grpo_outcome_advantage(
        new_token_level_rewards, response_mask, index, epsilon=1e-6, norm_adv_by_std_in_grpo=True
    )

    return advantages, returns


# Register our custom advantage estimator
AdvantageEstimatorRegistry.register("grpo_with_soft_overlong_punishment", compute_grpo_with_soft_overlong_punishment)


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    exp = BasePPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # validate the arguments
    validate_cfg(cfg)

    initialize_ray(cfg)
    ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
    main()

@@ -0,0 +1,83 @@
set -x

# Colocated DAPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K.

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/algorithms/dapo/run_dapo_gsm8k.sh

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb"  # change to "console" to print to stdout

# main DAPO parameters
EPS_CLIP_LOW=0.2
EPS_CLIP_HIGH=0.28
DYNAMIC_SAMPLING_TYPE=filter
DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES=30
LOSS_REDUCTION="token_mean"
# applies overlong filtering (but not soft overlong punishment)
APPLY_OVERLONG_FILTERING=true
# apply soft overlong punishment using the custom advantage estimator registered in main_dapo.py
ADV_ESTIMATOR="grpo_with_soft_overlong_punishment"

# other DAPO parameters
USE_KL_LOSS=false
TEMPERATURE=1.0
TOP_P=1.0
EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_RESPONSE_LENGTH=1024

uv run --isolated --extra vllm -m examples.algorithms.dapo.main_dapo \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator=$ADV_ESTIMATOR \
  trainer.algorithm.policy_loss_type="dual_clip" \
  trainer.algorithm.eps_clip_low=$EPS_CLIP_LOW \
  trainer.algorithm.eps_clip_high=$EPS_CLIP_HIGH \
  trainer.algorithm.dynamic_sampling.type=$DYNAMIC_SAMPLING_TYPE \
  trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
  trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
  generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
  generator.sampling_params.temperature=$TEMPERATURE \
  generator.sampling_params.top_p=$TOP_P \
  generator.eval_sampling_params.top_p=$EVAL_TOP_P \
  trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
  trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
  trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
  trainer.placement.colocate_all=true \
  trainer.strategy=fsdp2 \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
  generator.num_inference_engines=$NUM_GPUS \
  generator.inference_engine_tensor_parallel_size=1 \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=false \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=1024 \
  trainer.policy_mini_batch_size=256 \
  trainer.micro_forward_batch_size_per_gpu=64 \
  trainer.micro_train_batch_size_per_gpu=64 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
  trainer.policy.optimizer_config.lr=1.0e-6 \
  trainer.policy.optimizer_config.weight_decay=0.1 \
  trainer.policy.optimizer_config.max_grad_norm=1.0 \
  generator.backend=vllm \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k" \
  trainer.run_name="gsm8k_dapo" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
  $@