3 changes: 3 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -102,6 +102,9 @@ trainer:
    # value loss parameters
    value_clip: 0.2
    normalize_reward: true
    dynamic_sampling:
      type: null # filter, replace, or null/"none"
      max_sample_batches: 20 # sample at most this many batches before stopping, -1 to sample forever

  gradient_checkpointing: true
  gradient_checkpointing_use_reentrant: false
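For illustration, here is a minimal sketch (not part of this PR; the OmegaConf usage and the literal values are assumptions based on the cfg.trainer.algorithm.dynamic_sampling access pattern used in trainer.py) of how the new keys resolve once the config is loaded:

# Hypothetical sketch: read the new dynamic_sampling keys the same way the trainer does.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "trainer": {
            "algorithm": {"dynamic_sampling": {"type": "filter", "max_sample_batches": 20}},
            "train_batch_size": 1024,  # made-up value for illustration
        },
        "generator": {"n_samples_per_prompt": 8},  # made-up value for illustration
    }
)

if cfg.trainer.algorithm.dynamic_sampling.type is not None:
    print(cfg.trainer.algorithm.dynamic_sampling.max_sample_batches)  # 20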
70 changes: 70 additions & 0 deletions skyrl-train/skyrl_train/trainer.py
@@ -49,6 +49,7 @@
    validate_consistency_for_latest_checkpoint,
    calculate_per_dataset_metrics,
    dump_per_dataset_eval_results,
    handle_dynamic_sampling,
    GLOBAL_STEP_PREFIX,
)

@@ -92,6 +93,8 @@ def __init__(
        self.weights_manager: InferenceWeightsManager = None
        self.eval_weights_manager: InferenceWeightsManager = None

        self.dynamic_sampling_state = None

    def build_dataloader(self, dataset: PromptDataset, is_train=True):
        """
        Build the dataloader for the training or evaluation dataset
@@ -232,6 +235,7 @@ def train(self):
        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Step Progress")
        self.global_step += 1  # start training at global_step 1
        exit_loop = False
        for epoch in range(self.cfg.trainer.epochs):
            for iter, rand_prompts in enumerate(self.train_dataloader):
                with Timer("step", self.all_timings):
@@ -250,6 +254,14 @@ def train(self):
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = asyncio.run(self.generate(generator_input))

                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling, exit_loop = self.dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            continue
                        elif exit_loop:  # exit gracefully if we hit the max sample batches
                            break

                    # 1.2 postprocess rewards
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output = self.postprocess_generator_output(generator_output, uids)
@@ -328,6 +340,12 @@ def train(self):

                del training_input, generator_output

            if exit_loop:
                logger.info(
                    "Exiting training loop due to hitting dynamic sampling limit. Please check your data difficulty distribution."
                )
                break

        if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None:
            with Timer("update_ref_with_policy", self.all_timings):
                self.update_ref_with_policy()
@@ -1007,6 +1025,58 @@ def train_critic_and_policy(self, data: TrainingInputBatch):

        return policy_status

    def dynamic_sampling(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str], bool, bool]:
        """
        Implement dynamic sampling.

        Args:
            generator_output: Current batch generator output
            uids: Current batch UIDs

        Returns:
            processed_output: Filtered generator output
            processed_uids: Filtered UIDs
            keep_sampling: Whether to keep sampling
            exit_loop: Whether to exit the training loop (if we hit the max sample batches)
        """
        # Prepare sampling configuration
        max_sample_batches = self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches
        dynamic_sampling_config = {
            "type": self.cfg.trainer.algorithm.dynamic_sampling.type,
            "max_sample_batches": max_sample_batches,
            "train_batch_size": self.cfg.trainer.train_batch_size,
            "n_samples_per_prompt": getattr(self.cfg.generator, "n_samples_per_prompt", 1),
        }

        if self.dynamic_sampling_state is None:
            self.dynamic_sampling_state = {
                "sample_batch_count": 1,
            }
        else:
            self.dynamic_sampling_state["sample_batch_count"] += 1

        # Handle dynamic sampling using utilities
        processed_output, processed_uids, keep_sampling, updated_state = handle_dynamic_sampling(
            generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
        )

        # Check max resample limit, and if we hit it, return True for exit_loop
        if max_sample_batches > 0 and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches:
            logger.warning(
                f"Hit max resample batches ({max_sample_batches}), but there are still not enough good prompts, stopping sampling"
            )
            return None, None, False, True

        # Update state
        self.dynamic_sampling_state = updated_state

        if not keep_sampling:
            # Reset state when sampling is complete
            self.dynamic_sampling_state = None

        return processed_output, processed_uids, keep_sampling, False

    def _get_dp_group_models(self, rank: int, model_type: str = ""):
        model = getattr(self, model_type)
        if model_type == "reward_model":
249 changes: 247 additions & 2 deletions skyrl-train/skyrl_train/utils/trainer_utils.py
@@ -1,4 +1,4 @@
from typing import List, Dict, Any, Union, Callable, Optional
from typing import List, Dict, Any, Union, Callable, Optional, Tuple
from enum import Enum
import ray
from skyrl_train.workers.worker import PPORayActorGroup
@@ -8,7 +8,9 @@
from loguru import logger
import glob
import json
from skyrl_train.generators.utils import get_metrics_from_generator_output
import numpy as np
from collections import defaultdict
from skyrl_train.generators.utils import get_metrics_from_generator_output, concatenate_generator_outputs
from skyrl_train.generators.base import GeneratorOutput
from transformers import AutoTokenizer
from pathlib import Path
@@ -232,3 +234,246 @@ def dump_per_dataset_eval_results(
        f.write(json.dumps(eval_metrics, ensure_ascii=False) + "\n")

    print(f"Dumped aggregated eval metrics to {aggregated_filename}")


def handle_dynamic_sampling(
    generator_output: GeneratorOutput,
    uids: List[str],
    sampling_config: Dict[str, Any],
    collected_state: Optional[Dict[str, Any]] = None,
) -> Tuple[GeneratorOutput, List[str], bool, Optional[Dict[str, Any]]]:
    """
    Handle dynamic sampling with different strategies (filter, replace).

    filter (used in DAPO): filter out groups with std == 0 (when group size > 1) and resample until we have enough prompts.
    replace (used in POLARIS, WebSailor): replace bad (std == 0) samples with good (std > 0) samples.

    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs
        sampling_config: Configuration dict with sampling parameters
        collected_state: State for accumulating data across batches (for the filter strategy)

    Returns:
        Tuple of (processed_generator_output, processed_uids, keep_sampling, updated_state)
    """
    sampling_type = sampling_config.get("type", None)

    if sampling_type is None:
        return generator_output, uids, False, None

    if sampling_type == "replace":
        # For the "replace" strategy, handle immediately without accumulation
        processed_output, processed_uids, keep_sampling = handle_replace_sampling(
            generator_output, uids, sampling_config
        )
        return processed_output, processed_uids, keep_sampling, collected_state
    elif sampling_type == "filter":
        # For the "filter" strategy, accumulate filtered prompts across batches
        return handle_filter_sampling(generator_output, uids, sampling_config, collected_state)
    else:
        raise ValueError(f"Invalid dynamic sampling type: {sampling_type}")

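To make the std == 0 rule concrete, here is a toy illustration (made-up rewards, not from this PR) of the per-UID statistic both strategies compute:

import numpy as np
from collections import defaultdict

# Two prompt groups with two samples each (n_samples_per_prompt = 2).
uids = ["a", "a", "b", "b"]
rewards = [1.0, 0.0, 1.0, 1.0]

uid2metric_vals = defaultdict(list)
for uid, reward in zip(uids, rewards):
    uid2metric_vals[uid].append(reward)

print({uid: float(np.std(vals)) for uid, vals in uid2metric_vals.items()})
# {'a': 0.5, 'b': 0.0} -> group "a" is informative; group "b" gets filtered or replaced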

def get_bad_sample_replacements(good_uids: List[str], bad_uids: List[str]) -> List[str]:
    """Pick one good UID to stand in for each bad UID."""
    num_replacements = len(bad_uids)
    num_candidates = len(good_uids)

    if num_candidates >= num_replacements:
        # Enough good prompts: sample replacements without repetition
        perm = np.random.permutation(num_candidates)
        chosen_replacement_uids = np.array(list(good_uids))[perm[:num_replacements]]
    else:
        # Fewer good prompts than bad ones: sample with replacement
        indices = np.random.randint(low=0, high=num_candidates, size=(num_replacements,))
        chosen_replacement_uids = np.array(list(good_uids))[indices]

    return chosen_replacement_uids


def handle_replace_sampling(
    generator_output: GeneratorOutput, uids: List[str], sampling_config: Dict[str, Any]
) -> Tuple[GeneratorOutput, List[str], bool]:
    """
    Handle the replace sampling strategy based on the POLARIS implementation (https://github.com/ChenxinAn-fdu/POLARIS/).

    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs
        sampling_config: Configuration dict with sampling parameters

    Returns:
        Tuple of (processed_generator_output, processed_uids, keep_sampling)
    """
    n_samples_per_prompt = sampling_config.get("n_samples_per_prompt", 1)

    # Extract rewards and convert to sequence-level rewards if needed
    rewards = generator_output["rewards"]
    if isinstance(rewards[0], list):
        # Token-level rewards: sum to get sequence-level rewards
        rewards = np.array([sum(seq_rewards) for seq_rewards in rewards])
    else:
        rewards = np.array(rewards)

    # Map each UID to its trajectory indices and reward values
    uid2indices = defaultdict(list)
    uid2metric_vals = defaultdict(list)
    for idx, uid in enumerate(uids):
        uid2indices[uid].append(idx)
        uid2metric_vals[uid].append(rewards[idx])

    # Group by UID and calculate the reward std per group
    uid2metric_std = {}
    for uid, metric_vals in uid2metric_vals.items():
        uid2metric_std[uid] = np.std(metric_vals)

    # Determine good UIDs: those with std > 0 (or group size == 1)
    good_uids = set([uid for uid, std in uid2metric_std.items() if std > 0 or n_samples_per_prompt == 1])
    bad_uids = set([uid for uid, std in uid2metric_std.items() if std == 0 and n_samples_per_prompt > 1])

    logger.info(f"Replace sampling: {len(good_uids)} good UIDs out of {len(uid2metric_vals)} total prompts")

    # Check if we have enough good UIDs (more than 1/3 of the batch)
    if len(good_uids) > len(uid2metric_vals) // 3:
        logger.info("============= POLARIS dynamic sampling replace ===========")
        logger.info(f"Number of good prompts: {len(good_uids)}")
        logger.info(f"Number of bad prompts: {len(bad_uids)}")

        # Get good UIDs to replace the bad UIDs (one replacement per bad UID)
        replacement_uids = get_bad_sample_replacements(good_uids, bad_uids)
        # Get replacement indices
        replacement_indices = []
        for uid in replacement_uids:
            replacement_indices.extend(uid2indices[uid])
        # Get bad indices
        bad_indices = []
        for uid in bad_uids:
            bad_indices.extend(uid2indices[uid])

        # Replace bad samples with good ones (modify in place because replacement_idx and bad_idx should not overlap)
        for bad_idx, replacement_idx in zip(bad_indices, replacement_indices):
            generator_output["prompt_token_ids"][bad_idx] = generator_output["prompt_token_ids"][replacement_idx].copy()
            generator_output["response_ids"][bad_idx] = generator_output["response_ids"][replacement_idx].copy()
            if isinstance(generator_output["rewards"][replacement_idx], list):
                generator_output["rewards"][bad_idx] = generator_output["rewards"][replacement_idx].copy()
            else:
                generator_output["rewards"][bad_idx] = generator_output["rewards"][replacement_idx]
            generator_output["loss_masks"][bad_idx] = generator_output["loss_masks"][replacement_idx].copy()
            if generator_output["stop_reasons"]:
                generator_output["stop_reasons"][bad_idx] = generator_output["stop_reasons"][replacement_idx]

        # Update UIDs accordingly
        replaced_uids = uids.copy()
        for bad_idx, replacement_idx in zip(bad_indices, replacement_indices):
            replaced_uids[bad_idx] = uids[replacement_idx]

        logger.info(f"After replacement - Replaced {len(bad_indices) // n_samples_per_prompt} bad prompts")

        return generator_output, replaced_uids, False
    else:
        logger.warning("===================== Warning ====================")
        logger.warning("In this mini-batch, most training samples receive low variance rewards.")
        logger.warning("If you continue to see this warning, please check your data difficulty distribution.")
        logger.warning("==================================================")

        return generator_output, uids, True

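As a quick illustration of the replacement step (toy UIDs, illustrative only): each bad UID is paired with a randomly chosen good UID, and sampling happens with replacement only when there are fewer good UIDs than bad ones.

from skyrl_train.utils.trainer_utils import get_bad_sample_replacements

good_uids = ["u1", "u2", "u3"]  # groups with reward std > 0
bad_uids = ["u4", "u5"]  # groups with reward std == 0
replacements = get_bad_sample_replacements(good_uids, bad_uids)
print(replacements)  # e.g. ['u3' 'u1'] -- one good UID drawn per bad UID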

def handle_filter_sampling(
    generator_output: GeneratorOutput,
    uids: List[str],
    sampling_config: Dict[str, Any],
    collected_state: Dict[str, Any],
) -> Tuple[GeneratorOutput, List[str], bool, Dict[str, Any]]:
    """
    Handle the filter-based sampling strategy.

    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs
        sampling_config: Configuration dict with sampling parameters
        collected_state: State for accumulating data across batches

    Returns:
        Tuple of (processed_generator_output, processed_uids, keep_sampling, updated_state)
    """
    target_batch_size = sampling_config.get("train_batch_size")
    n_samples_per_prompt = sampling_config.get("n_samples_per_prompt", 1)

    # Extract rewards and convert to sequence-level rewards if needed
    rewards = generator_output["rewards"]
    if isinstance(rewards[0], list):
        # Token-level rewards: sum to get sequence-level rewards
        rewards = np.array([sum(seq_rewards) for seq_rewards in rewards])
    else:
        rewards = np.array(rewards)

    # Group by UID and calculate the reward standard deviation per group
    uid2metric_vals = defaultdict(list)
    for uid, reward in zip(uids, rewards):
        uid2metric_vals[uid].append(reward)

    uid2metric_std = {}
    for uid, metric_vals in uid2metric_vals.items():
        uid2metric_std[uid] = np.std(metric_vals)

    # Filter out groups with std == 0 when group size > 1
    kept_uids = [uid for uid, std in uid2metric_std.items() if std > 0 or n_samples_per_prompt == 1]
    kept_uids_set = set(kept_uids)

    # Keep only trajectories whose UID survived the filter
    kept_traj_idxs = []
    for idx, traj_uid in enumerate(uids):
        if traj_uid in kept_uids_set:
            kept_traj_idxs.append(idx)

    # Apply filtering to the generator output
    filtered_output = filter_generator_output(generator_output, kept_traj_idxs)
    filtered_uids = [uids[idx] for idx in kept_traj_idxs]

    # Accumulate filtered trajectories across resampled batches
    if "collected_generator_output" not in collected_state:
        collected_state.update(
            {
                "collected_generator_output": filtered_output,
                "collected_uids": filtered_uids.copy(),
                "num_prompts_in_batch": len(kept_uids),
            }
        )
    else:
        collected_state["collected_generator_output"] = concatenate_generator_outputs(
            [collected_state["collected_generator_output"], filtered_output]
        )
        collected_state["collected_uids"].extend(filtered_uids)
        collected_state["num_prompts_in_batch"] += len(kept_uids)

    # Check if we have accumulated enough prompts
    if collected_state["num_prompts_in_batch"] < target_batch_size:
        logger.info(f"Dynamic sampling: {collected_state['num_prompts_in_batch']} < {target_batch_size} prompts")
        logger.info(f"Resample batch {collected_state['sample_batch_count']}, continue sampling...")
        return generator_output, uids, True, collected_state
    else:
        logger.info(
            f"Dynamic sampling: collected {collected_state['num_prompts_in_batch']} >= {target_batch_size} prompts"
        )
        # Truncate to the exact batch size if needed
        max_trajectories = target_batch_size * n_samples_per_prompt
        final_output = collected_state["collected_generator_output"]
        final_uids = collected_state["collected_uids"]

        if len(final_uids) > max_trajectories:
            final_output = filter_generator_output(final_output, list(range(max_trajectories)))
            final_uids = final_uids[:max_trajectories]

        return final_output, final_uids, False, None


def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> GeneratorOutput:
    """Filter GeneratorOutput based on kept indices."""
    filtered = {
        "prompt_token_ids": [output["prompt_token_ids"][i] for i in kept_indices],
        "response_ids": [output["response_ids"][i] for i in kept_indices],
        "rewards": [output["rewards"][i] for i in kept_indices],
        "loss_masks": [output["loss_masks"][i] for i in kept_indices],
        "stop_reasons": None,
        "rollout_metrics": output.get("rollout_metrics"),
    }

    if output.get("stop_reasons"):
        filtered["stop_reasons"] = [output["stop_reasons"][i] for i in kept_indices]

    return filtered
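Tying the utilities together, here is a minimal end-to-end sketch of the filter path (dummy token ids and rewards, all made up for illustration). With train_batch_size set to 1, the first batch already yields enough informative prompts, so keep_sampling comes back False and no resampling is needed:

from skyrl_train.utils.trainer_utils import handle_dynamic_sampling

sampling_config = {
    "type": "filter",
    "max_sample_batches": 20,
    "train_batch_size": 1,  # made-up target: one prompt per training step
    "n_samples_per_prompt": 2,  # made-up group size
}
state = {"sample_batch_count": 1}

# Two prompt groups of two samples each; group "b" has zero reward variance.
generator_output = {
    "prompt_token_ids": [[1], [1], [2], [2]],
    "response_ids": [[3], [4], [5], [6]],
    "rewards": [1.0, 0.0, 0.0, 0.0],
    "loss_masks": [[1], [1], [1], [1]],
    "stop_reasons": None,
    "rollout_metrics": None,
}
uids = ["a", "a", "b", "b"]

output, kept_uids, keep_sampling, state = handle_dynamic_sampling(
    generator_output, uids, sampling_config, state
)
print(kept_uids, keep_sampling)  # ['a', 'a'] False -- group "b" was filtered out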