Merged
36 changes: 31 additions & 5 deletions skyrl-train/docs/checkpointing-logging/checkpointing.rst
@@ -28,24 +28,50 @@ FSDP checkpoints are organized according to the following directory hierarchy:
.. code-block::

{ckpt_path}/
├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint
├── global_step_10/ # Checkpoint at training step 10
│   ├── policy/                        # Policy model checkpoint directory
│   │   ├── fsdp_config.json           # Stores FSDP strategy and world size
│   │   ├── huggingface/               # HuggingFace config and tokenizer
│   │   │   ├── config.json            # Model config
│   │   │   ├── tokenizer_config.json  # Tokenizer config
│   │   │   ├── generation_config.json # Generation config
│   │   │   └── ...                    # Other tokenizer config files
│   │   ├── model_state.pt             # Model parameters
│   │   ├── optimizer_state.pt         # Optimizer state
│   │   └── lr_scheduler_state.pt      # Learning rate scheduler state
│   ├── critic/                        # Critic model checkpoint (if enabled)
│   │   ├── fsdp_config.json
│   │   ├── huggingface/
│   │   ├── model_state.pt
│   │   ├── optimizer_state.pt
│   │   └── lr_scheduler_state.pt
│   ├── data.pt                        # Dataloader state
│   └── trainer_state.pt               # High-level trainer state
├── global_step_20/ # Checkpoint at training step 20
│ └── ...
└── global_step_30/ # Checkpoint at training step 30
└── ...
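
The ``latest_ckpt_global_step.txt`` tracking file makes it easy to resolve the most recent checkpoint programmatically. Below is a minimal sketch of loading the latest policy state for offline inspection, assuming the layout above; the ``ckpt_path`` value is hypothetical, and the exact contents of ``model_state.pt`` depend on the FSDP strategy and world size recorded in ``fsdp_config.json``.

.. code-block:: python

    import os
    import torch

    ckpt_path = "/path/to/ckpts"  # hypothetical checkpoint root
    with open(os.path.join(ckpt_path, "latest_ckpt_global_step.txt")) as f:
        step = int(f.read().strip())

    policy_dir = os.path.join(ckpt_path, f"global_step_{step}", "policy")
    # Load on CPU to inspect parameter names, shapes, and dtypes offline.
    model_state = torch.load(os.path.join(policy_dir, "model_state.pt"), map_location="cpu")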

DeepSpeed checkpoints follow a similar directory structure, but the model checkpoint files under ``policy`` and ``critic`` are created by the DeepSpeed checkpoint API and are not explicitly managed by SkyRL.

.. code-block::

{ckpt_path}/
├── latest_ckpt_global_step.txt        # Holds the global step of the latest checkpoint
├── global_step_10/                    # Checkpoint at training step 10
│   ├── policy/                        # Policy model checkpoint directory
│   │   ├── huggingface/               # HuggingFace config and tokenizer
│   │   ├── global_step10/             # DeepSpeed checkpoint directory
│   │   └── ...                        # Other DeepSpeed checkpointing files
│   └── critic/                        # Critic model checkpoint (if enabled)
│       ├── huggingface/
│       ├── global_step10/
│       └── ...
├── global_step_20/ # Checkpoint at training step 20
│ └── ...
└── global_step_30/ # Checkpoint at training step 30
└── ...
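
Restoring a DeepSpeed checkpoint goes through the DeepSpeed engine rather than ``torch.load``. A minimal sketch, assuming an already-initialized engine from ``deepspeed.initialize`` (``model_engine`` and ``ckpt_path`` are stand-ins, not names from SkyRL):

.. code-block:: python

    import os

    ckpt_path = "/path/to/ckpts"  # hypothetical checkpoint root
    with open(os.path.join(ckpt_path, "latest_ckpt_global_step.txt")) as f:
        step = int(f.read().strip())

    policy_dir = os.path.join(ckpt_path, f"global_step_{step}", "policy")
    # With tag=None, DeepSpeed resolves the tag (e.g. "global_step10") from
    # the "latest" file it wrote inside policy_dir at save time.
    load_path, client_state = model_engine.load_checkpoint(policy_dir, tag=None)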


Key Configuration Parameters
6 changes: 6 additions & 0 deletions skyrl-train/skyrl_train/distributed/deepspeed_strategy.py
@@ -257,6 +257,7 @@ def save_ckpt(
        scheduler=None,
        client_state={},
        tag=None,
        tokenizer=None,
    ):
        if isinstance(model, Actor):
            model = model.model
@@ -277,6 +278,11 @@

        model.save_checkpoint(ckpt_dir, tag=tag, client_state=extra_state_dict)

        # Save HuggingFace config and tokenizer
        if self.is_rank_0():
            config_save_model = self._unwrap_model(model)
            self.save_hf_configs(config_save_model, ckpt_dir, tokenizer)

    def load_ckpt(
        self,
        model,
11 changes: 11 additions & 0 deletions skyrl-train/skyrl_train/distributed/fsdp_strategy.py
@@ -6,6 +6,7 @@
from typing import List, Union, Optional
from jaxtyping import Float
import gc
import json

import numpy as np
import torch
@@ -374,6 +375,7 @@ def save_ckpt(
        scheduler=None,
        client_state={},
        tag=None,
        tokenizer=None,
    ):
        """Save model checkpoint for FSDP"""
        import warnings
@@ -445,6 +447,15 @@
        # Garbage collect temporary buffers from materializing the state dicts
        gc.collect()

        if self.is_rank_0():
            config_save_model = self._unwrap_model(model)
            self.save_hf_configs(config_save_model, ckpt_dir, tokenizer)

            # Also save runtime FSDP config
            fsdp_config_path = os.path.join(ckpt_dir, "fsdp_config.json")
            with open(fsdp_config_path, "w") as f:
                json.dump({"fsdp_strategy": self.fsdp_strategy, "world_size": self.world_size}, f, indent=4)

        # Final barrier to ensure all operations complete
        dist.barrier()
        torch.cuda.synchronize()
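The ``fsdp_config.json`` written above records how the checkpoint was sharded. A minimal sketch of a resume-time sanity check built on it; the path and the check are illustrative, not SkyRL's actual load path:

.. code-block:: python

    import json
    import os

    ckpt_dir = "/path/to/ckpts/global_step_10/policy"  # hypothetical
    with open(os.path.join(ckpt_dir, "fsdp_config.json")) as f:
        saved = json.load(f)

    expected_world_size = 2  # whatever the resumed job runs with
    if saved["world_size"] != expected_world_size:
        raise ValueError(
            f"Checkpoint was saved with world_size={saved['world_size']}, "
            f"but this run uses {expected_world_size}"
        )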
32 changes: 31 additions & 1 deletion skyrl-train/skyrl_train/distributed/strategy.py
@@ -1,4 +1,5 @@
import random
import os
from abc import ABC, abstractmethod

import numpy as np
@@ -7,6 +8,7 @@
from typing import Optional, Dict, Any, Union, TypeVar
import torch.optim as optim
from jaxtyping import Float
from transformers import GenerationConfig


DataT = TypeVar("DataT", bound=Union[Dict[str, Any], torch.Tensor])
@@ -45,7 +47,7 @@ def optimizer_step(
        pass

    @abstractmethod
    def save_ckpt(self, model, optimizer, scheduler, ckpt_dir, global_step, node_local_rank, tokenizer=None):
        """Save checkpoint"""
        pass

@@ -72,6 +74,34 @@ def get_rank(self) -> int:
"""Get current process rank"""
return dist.get_rank()

def save_hf_configs(self, model, ckpt_dir: str, tokenizer=None):
"""
Save model and tokenizer configs to ckpt_dir/huggingface

Args:
model: AutoModel - the model to save the configs for
ckpt_dir: str - the directory to save the configs to
tokenizer: AutoTokenizer - tokenizer to save
"""
hf_config_tokenizer_path = os.path.join(ckpt_dir, "huggingface")
os.makedirs(hf_config_tokenizer_path, exist_ok=True)
model_config = model.config
generation_config = None
if model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path:
try:
# Some model's name_or_path is empty if not initialized from pretrained,
# in this cases, we don't save generation config.
generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)
generation_config.save_pretrained(hf_config_tokenizer_path)
except Exception as e:
# if the generation config isn't available, we don't save it
print(f"Warning: Could not save generation config for '{model_config.name_or_path}'. Error: {e}")
pass

model_config.save_pretrained(hf_config_tokenizer_path)
if tokenizer is not None:
tokenizer.save_pretrained(hf_config_tokenizer_path)

    @staticmethod
    def get_rng_state():
        """Get current RNG state for reproducibility"""
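Because ``save_hf_configs`` writes standard HuggingFace artifacts, the saved ``huggingface/`` directory can be consumed directly with ``from_pretrained``. A minimal sketch, assuming a checkpoint produced by ``save_ckpt`` (the path is hypothetical):

.. code-block:: python

    import os
    from transformers import AutoConfig, AutoTokenizer, GenerationConfig

    ckpt_dir = "/path/to/ckpts/global_step_10/policy"  # hypothetical
    hf_dir = os.path.join(ckpt_dir, "huggingface")

    config = AutoConfig.from_pretrained(hf_dir)
    tokenizer = AutoTokenizer.from_pretrained(hf_dir)
    # Present only if a generation config could be saved at checkpoint time.
    generation_config = GenerationConfig.from_pretrained(hf_dir)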
2 changes: 2 additions & 0 deletions skyrl-train/skyrl_train/trainer.py
@@ -1037,6 +1037,7 @@ def save_checkpoints(self):
"save_ckpt",
global_step=self.global_step,
ckpt_dir=policy_save_dir,
tokenizer=self.tokenizer,
)
)

@@ -1052,6 +1053,7 @@
"save_ckpt",
global_step=self.global_step,
ckpt_dir=critic_save_dir,
tokenizer=self.tokenizer,
)
)

6 changes: 4 additions & 2 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -847,14 +847,15 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
status["response_length"] = num_actions
return status

def save_ckpt(self, global_step: int, ckpt_dir: Path):
def save_ckpt(self, global_step: int, ckpt_dir: Path, tokenizer=None):
self.strategy.save_ckpt(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
ckpt_dir=ckpt_dir,
global_step=global_step,
node_local_rank=self.get_node_local_rank(),
tokenizer=tokenizer,
)

def load_ckpt(self, ckpt_dir: Path, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True):
@@ -1052,14 +1053,15 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
status["raw_grad_norm"] = grad_norm
return status

def save_ckpt(self, global_step: int, ckpt_dir: str):
def save_ckpt(self, global_step: int, ckpt_dir: str, tokenizer=None):
self.strategy.save_ckpt(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
ckpt_dir=ckpt_dir,
global_step=global_step,
node_local_rank=self.get_node_local_rank(),
tokenizer=tokenizer,
)

def load_ckpt(self, ckpt_dir=None, load_optimizer_states=True, load_lr_scheduler_states=True):
30 changes: 24 additions & 6 deletions skyrl-train/tests/gpu/test_save_load_ckpt.py
@@ -1,6 +1,6 @@
"""
Run with:
uv run --isolated --extra dev --with deepspeed -- pytest tests/gpu/test_save_load_ckpt.py
"""

import ray
@@ -9,7 +9,9 @@
import torch
import os
import shutil
import json
from omegaconf import DictConfig
from transformers import AutoTokenizer

from tests.gpu.utils import init_worker_with_type, make_dummy_experience, get_model_logits_from_actor
from skyrl_train.entrypoints.main_base import config_dir
@@ -40,7 +42,7 @@ def get_test_actor_config(strategy: str) -> DictConfig:
"fsdp2",
],
)
def test_save_load_checkpoint(strategy):
def test_save_load_checkpoint(ray_init_fixture, strategy):
"""
Test checkpointing logic by:
1. Creating model and doing one training step
@@ -59,6 +61,7 @@ def test_save_load_checkpoint(strategy):
        num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
        cfg=cfg,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    checkpoint_dir = None
    # Create dummy experiences for training steps
@@ -82,7 +85,25 @@
        checkpoint_dir = os.path.expandvars(os.path.join(cfg.trainer.ckpt_path, "global_step_1"))  # Store for cleanup

        # Step 2: Save checkpoint
        ray.get(
            actor_group.async_run_ray_method(
                "pass_through", "save_ckpt", global_step=1, ckpt_dir=checkpoint_path, tokenizer=tokenizer
            )
        )

        # check that relevant files are saved
        huggingface_dir = os.path.join(checkpoint_path, "huggingface")
        expected_files = ["config.json", "generation_config.json", "tokenizer.json"]
        for file in expected_files:
            assert os.path.exists(
                os.path.join(huggingface_dir, file)
            ), f"File {file} not found in huggingface directory"
        if "fsdp" in strategy:
            fsdp_config_path = os.path.join(checkpoint_path, "fsdp_config.json")
            with open(fsdp_config_path, "r") as f:
                fsdp_config = json.load(f)
            assert fsdp_config["fsdp_strategy"] == strategy
            assert fsdp_config["world_size"] == 2

        # Step 3: Do second training step and record results
        ray.get(
@@ -117,9 +138,6 @@
        torch.testing.assert_close(logits_after_second_training, logits_after_reload_and_training, atol=0.0, rtol=0.0)

    finally:
        # Clean up checkpoint directory
        if checkpoint_dir and os.path.exists(checkpoint_dir):
            print(f"Removing checkpoint directory: {checkpoint_dir}")
40 changes: 8 additions & 32 deletions skyrl-train/tests/gpu/test_trainer_full_checkpointing.py
@@ -5,7 +5,7 @@
ensuring that training can resume exactly where it left off.

Run with:
uv run --isolated --extra dev --with deepspeed -- pytest tests/gpu/test_trainer_full_checkpointing.py
"""

import ray
@@ -18,10 +18,11 @@
from omegaconf import DictConfig
from torch.utils.data import Dataset
from unittest.mock import MagicMock
from transformers import AutoTokenizer

from skyrl_train.utils.tracking import Tracking
from skyrl_train.trainer import RayPPOTrainer
from tests.gpu.utils import import_worker, ray_init_for_tests
from skyrl_train.entrypoints.main_base import config_dir

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -43,25 +44,6 @@ def collate_fn(self, batch):
        return batch



def get_test_trainer_config(strategy: str, fsdp2_cpu_offload: bool = False) -> DictConfig:
"""Create minimal trainer config for testing"""
with hydra.initialize_config_dir(config_dir=config_dir):
@@ -75,10 +57,10 @@ def get_test_trainer_config(strategy: str, fsdp2_cpu_offload: bool = False) -> D

        # Use minimal settings for faster testing
        cfg.trainer.placement.policy_num_gpus_per_node = 2
        cfg.trainer.placement.critic_num_gpus_per_node = 2
        cfg.trainer.placement.policy_num_nodes = 1
        cfg.trainer.placement.critic_num_nodes = 1
        cfg.trainer.algorithm.use_kl_loss = False  # disable ref model so we just have policy and critic (4 GPUs)
        cfg.trainer.placement.colocate_all = False  # Disable colocation for simpler testing
        cfg.trainer.train_batch_size = 2
        cfg.trainer.micro_train_batch_size_per_gpu = 1
Expand All @@ -103,7 +85,7 @@ def get_test_trainer_config(strategy: str, fsdp2_cpu_offload: bool = False) -> D
def create_minimal_trainer(cfg: DictConfig):
"""Create a minimal trainer setup for testing"""
# Create minimal tokenizer
tokenizer = MinimalTokenizer()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Create dummy dataset
    train_dataset = DummyDataset(size=4)  # Small dataset for quick testing
@@ -152,7 +134,7 @@ def capture_training_state(trainer):
("fsdp2", True),
],
)
def test_trainer_full_checkpointing(strategy, fsdp2_cpu_offload):
def test_trainer_full_checkpointing(ray_init_fixture, strategy, fsdp2_cpu_offload):
"""
Test full trainer checkpointing by:
1. Creating trainer and setting it up
@@ -228,7 +210,7 @@ def test_trainer_full_checkpointing(strategy, fsdp2_cpu_offload):

        # ============= PHASE 2: Resume from Checkpoint =============
        print("Phase 2: Resume from checkpoint")

        ray_init_for_tests()
        # Create new config with resume enabled
        cfg_resume = get_test_trainer_config(strategy, fsdp2_cpu_offload)
        cfg_resume.trainer.resume_mode = "from_path"  # Enable resume
@@ -275,12 +257,6 @@
        assert latest_step == trainer2.global_step, "Atomic tracking file was not updated after second save"

    finally:
        if checkpoint_dir and os.path.exists(os.path.dirname(checkpoint_dir)):
            print(f"Cleaning up checkpoint directory: {os.path.dirname(checkpoint_dir)}")
            shutil.rmtree(os.path.dirname(checkpoint_dir))