
[Models] Activation checkpointing from TorchTune #2954


Merged
merged 45 commits into main from activation-checkpoint on May 7, 2025

Commits (45)
a1a2582
initial activation offloading from torchtune
kashif Feb 25, 2025
1e5b989
formatting
kashif Feb 25, 2025
12afc5a
add support to SFT Trainer
kashif Feb 25, 2025
c25facb
License and minor style changes
qgallouedec Feb 25, 2025
7632e5e
Doc
qgallouedec Feb 25, 2025
408284c
fix path in docs
qgallouedec Feb 25, 2025
0d156db
use activation_offloading
kashif Feb 25, 2025
0aa23ad
only doc get_act_offloading_ctx_manager
qgallouedec Feb 25, 2025
bfb1333
Merge branch 'activation-checkpoint' of https://github.com/huggingfac…
qgallouedec Feb 25, 2025
acc921c
✋ Prevent applying the chat template to tokenized datasets (#2939)
DanFosing Feb 24, 2025
f19bfa0
📇 GRPO: print completions to console and update docs (#2951)
nopepper Feb 24, 2025
31963ec
↩️ Fix typo in TextEnvironment init param, should be max_tool_respons…
shenxiangzhuang Feb 24, 2025
f66feb2
🗿 Updated DPO default values for alpha and tau (#2918)
Ishan-Kumar2 Feb 24, 2025
5b5caf1
📌 Pin liger-kernel and vLLM (#2952)
qgallouedec Feb 24, 2025
e7dcbfa
⏪ Parameterize `enable_prefix_caching` (#2900)
ji-huazhong Feb 24, 2025
7eb5e68
🔢 Fix GRPO doc about `num_iterations` (#2966)
qgallouedec Feb 26, 2025
a483c89
Update grpo_trainer.py (#2973)
tpoisonooo Feb 27, 2025
bf830c5
Merge branch 'main' into activation-checkpoint
kashif Mar 3, 2025
95a52d9
Merge branch 'main' into activation-checkpoint
kashif Mar 18, 2025
eb4d315
only for training
kashif Mar 18, 2025
04ea97f
better docs
kashif Mar 18, 2025
6443645
Merge branch 'main' into activation-checkpoint
kashif Mar 19, 2025
141abc4
ignore any modules with Liger
kashif Mar 19, 2025
8a0da7c
Only offload if activation is on CUDA
kashif Apr 1, 2025
5365106
Merge branch 'main' into activation-checkpoint
kashif Apr 1, 2025
02d936c
Merge branch 'main' into activation-checkpoint
kashif Apr 26, 2025
7abe1e5
fix for pytorch 2.4
kashif Apr 26, 2025
e3756c2
add docs
kashif Apr 26, 2025
d3e0726
Validate that activation_offloading and use_liger_kernel aren't both …
kashif Apr 26, 2025
74c3b55
fix CI
kashif Apr 26, 2025
eb06992
add tip
kashif Apr 26, 2025
d74cf84
Revert "add tip"
kashif Apr 26, 2025
9db9043
Merge branch 'main' into activation-checkpoint
kashif Apr 26, 2025
1935673
fix tests
kashif Apr 27, 2025
02ddbe5
Update docs/source/reducing_memory_usage.md
kashif Apr 28, 2025
9553718
Update docs/source/reducing_memory_usage.md
kashif Apr 28, 2025
d0d26e3
move sft test
qgallouedec Apr 28, 2025
96348d6
fix license
qgallouedec Apr 28, 2025
b317b0f
maybe_activation_offload_context and some doc
qgallouedec Apr 28, 2025
ff7dd54
better arg ordering and move param validation in SFTTrainer
qgallouedec Apr 28, 2025
0df3453
disable warning
qgallouedec Apr 28, 2025
82e3f91
better check for peft
kashif Apr 29, 2025
6cbd67d
Merge branch 'main' into activation-checkpoint
qgallouedec May 2, 2025
9ce14f8
Merge branch 'main' into activation-checkpoint
kashif May 7, 2025
c43f58f
move tests to SLOW as it needs accelerator
kashif May 7, 2025

Files changed (the diff below shows changes from 3 commits)
184 changes: 184 additions & 0 deletions tests/test_activation_offloading.py
@@ -0,0 +1,184 @@
import gc
import tempfile
import unittest

import torch
from datasets import load_dataset
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_peft, require_torch_accelerator
from transformers.utils import is_peft_available

from trl.models.activation_offloading import NoOpManager, OffloadActivations
from trl.trainer.sft_trainer import SFTConfig, SFTTrainer


if is_peft_available():
    from peft import LoraConfig, get_peft_model


class TestActivationOffloading(unittest.TestCase):
    def setUp(self):
        self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
        self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
        self.max_length = 128
        self.peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=8,
            bias="none",
            task_type="CAUSAL_LM",
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    @require_torch_accelerator
    def test_offloading_with_sft_trainer(self) -> None:
        """Test that activation offloading works with SFTTrainer."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=1,
                enable_activation_offloading=True,
                report_to="none",
            )

            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
            )

            # Train for one step
            trainer.train()

            # Verify training completed successfully
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    @require_torch_accelerator
    @require_peft
    def test_offloading_with_peft_models(self) -> None:
        """Test that activation offloading works with PEFT models."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        model = get_peft_model(model, self.peft_config)
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # First forward-backward pass without offloading
        torch.manual_seed(42)
        loss = model(inp, labels=inp).loss
        loss.backward()
        # Store gradients - only from trainable parameters
        grads_original = []
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grads_original.append((name, param.grad.clone()))

        # Reset gradients
        for p in model.parameters():
            if p.grad is not None:
                p.grad = None

        # Second forward-backward pass with offloading
        torch.manual_seed(42)
        with OffloadActivations(use_streams=True):
            loss_c = model(inp, labels=inp).loss
            loss_c.backward()

        # Compare gradients - only trainable parameters
        for name_orig, grad_orig in grads_original:
            for name_param, param in model.named_parameters():
                if name_param == name_orig and param.requires_grad and param.grad is not None:
                    self.assertTrue(
                        torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
                        f"Gradient mismatch for {name_orig}",
                    )

    @require_torch_accelerator
    def test_noop_manager_with_offloading(self):
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # Run with offloading but disable for specific section
        with OffloadActivations(use_streams=True):
            # First forward-backward with normal offloading
            torch.manual_seed(42)
            out1 = model(inp, labels=inp)
            out1.loss.backward()
            grads1 = [p.grad.clone() for p in model.parameters()]

            # Reset grads
            for p in model.parameters():
                p.grad = None

            # Second forward-backward with NoOpManager
            with NoOpManager():
                torch.manual_seed(42)
                out2 = model(inp, labels=inp)
                out2.loss.backward()

            grads2 = [p.grad.clone() for p in model.parameters()]

        # Gradients should match as NoOpManager should have prevented offloading
        for g1, g2 in zip(grads1, grads2):
            self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))

    @require_torch_accelerator
    def test_min_offload_size(self):
        """Test that tensors smaller than min_offload_size aren't offloaded"""
        model = nn.Sequential(
            nn.Linear(5, 5),  # Small layer that shouldn't be offloaded
            nn.Linear(5, 1000),  # Large layer that should be offloaded
        ).cuda()

        inp = torch.randn(2, 5, device="cuda")

        with OffloadActivations(min_offload_size=1000):
            out = model(inp)
            out.sum().backward()

        # The test passes if no errors occur, as we're mainly testing
        # that the logic handles both offloaded and non-offloaded tensors

    @require_torch_accelerator
    def test_real_hf_model(self):
        """Test with an actual HuggingFace model"""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        # Create small input
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # Baseline without offloading
        torch.manual_seed(42)
        out1 = model(inp, labels=inp).loss
        out1.backward()
        grads1 = [p.grad.clone() for p in model.parameters()]

        # Reset grads
        for p in model.parameters():
            p.grad = None

        # With offloading
        with OffloadActivations(use_streams=True):
            torch.manual_seed(42)
            out2 = model(inp, labels=inp).loss
            out2.backward()

        grads2 = [p.grad.clone() for p in model.parameters()]

        # Check outputs and gradients match
        self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
        for g1, g2 in zip(grads1, grads2):
            self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
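The new test file above also doubles as a usage recipe: in the trainer path, activation offloading is switched on through a single SFTConfig flag. The sketch below condenses test_offloading_with_sft_trainer into a minimal training script. It assumes a CUDA device; the output directory and step count are arbitrary placeholders, and the flag name `enable_activation_offloading` reflects this 3-commit snapshot of the diff (a later commit in the PR suggests a rename, so check the merged docs for the final name).

```python
# Minimal sketch, based on test_offloading_with_sft_trainer above.
# Assumes a CUDA device; output_dir and max_steps are arbitrary placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM

from trl.trainer.sft_trainer import SFTConfig, SFTTrainer

# Tiny model and dataset slice taken from the test itself.
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5").cuda()
train_dataset = load_dataset("stanfordnlp/imdb", split="train[:1%]")

training_args = SFTConfig(
    output_dir="./sft-activation-offloading-demo",
    per_device_train_batch_size=2,
    max_steps=10,
    enable_activation_offloading=True,  # offload forward activations to CPU, reload them for backward
    report_to="none",
)

trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```

The same flag composes with PEFT adapters, as test_offloading_with_peft_models exercises; per the commit history, it cannot be combined with use_liger_kernel, and a later commit in this PR adds a validation for exactly that.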
2 changes: 2 additions & 0 deletions trl/models/__init__.py
@@ -18,6 +18,7 @@


_import_structure = {
    "activation_offloading": ["get_act_offloading_ctx_manager"],
    "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
    "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
    "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
@@ -37,6 +38,7 @@
    ]

if TYPE_CHECKING:
    from .activation_offloading import get_act_offloading_ctx_manager
    from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
    from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
    from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation
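Beyond the trainer flag, the offloading context managers can be driven by hand, which is what the tests above do. The sketch below mirrors that pattern and only uses OffloadActivations and NoOpManager, the objects the test file imports directly; get_act_offloading_ctx_manager, newly exported from trl.models in this diff, is the public helper, but its signature is not part of the shown changes, so the sketch does not guess it.

```python
# Minimal sketch of manual activation offloading, mirroring the tests above.
# Assumes a CUDA device and the tiny test checkpoint used throughout this PR.
import torch
from transformers import AutoModelForCausalLM

from trl.models.activation_offloading import NoOpManager, OffloadActivations

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5").cuda()
inp = torch.randint(0, 100, (2, 10), device="cuda")

# Offload saved activations to CPU during forward and reload them for backward;
# use_streams=True overlaps the copies with compute, as in the tests.
with OffloadActivations(use_streams=True):
    loss = model(inp, labels=inp).loss
    loss.backward()

    # NoOpManager disables offloading for a nested region while the outer context
    # stays active, as exercised by test_noop_manager_with_offloading.
    with NoOpManager():
        loss2 = model(inp, labels=inp).loss
        loss2.backward()
```

Per the commit history, only activations that already live on CUDA are offloaded, and the SFTTrainer applies the context manager during training only.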