Merged
Changes from 40 commits
Commits
45 commits
a1a2582
initial activaton offloading from torchtune
kashif Feb 25, 2025
1e5b989
formatting
kashif Feb 25, 2025
12afc5a
add support to SFT Trainer
kashif Feb 25, 2025
c25facb
License and minor style changes
qgallouedec Feb 25, 2025
7632e5e
Doc
qgallouedec Feb 25, 2025
408284c
fix path in docs
qgallouedec Feb 25, 2025
0d156db
use activation_offloading
kashif Feb 25, 2025
0aa23ad
only doc get_act_offloading_ctx_manager
qgallouedec Feb 25, 2025
bfb1333
Merge branch 'activation-checkpoint' of https://github.com/huggingfac…
qgallouedec Feb 25, 2025
acc921c
✋ Prevent applying the chat template to tokenized datasets (#2939)
DanFosing Feb 24, 2025
f19bfa0
📇 GRPO: print completions to console and update docs (#2951)
nopepper Feb 24, 2025
31963ec
↩️ Fix typo in TextEnvironment init param, should be max_tool_respons…
shenxiangzhuang Feb 24, 2025
f66feb2
🗿 Updated DPO default values for alpha and tau (#2918)
Ishan-Kumar2 Feb 24, 2025
5b5caf1
📌 Pin liger-kernel and vLLM (#2952)
qgallouedec Feb 24, 2025
e7dcbfa
⏪ Parameterize `enable_prefix_caching` (#2900)
ji-huazhong Feb 24, 2025
7eb5e68
🔢 Fix GRPO doc about `num_iterations` (#2966)
qgallouedec Feb 26, 2025
a483c89
Update grpo_trainer.py (#2973)
tpoisonooo Feb 27, 2025
bf830c5
Merge branch 'main' into activation-checkpoint
kashif Mar 3, 2025
95a52d9
Merge branch 'main' into activation-checkpoint
kashif Mar 18, 2025
eb4d315
only for training
kashif Mar 18, 2025
04ea97f
better docs
kashif Mar 18, 2025
6443645
Merge branch 'main' into activation-checkpoint
kashif Mar 19, 2025
141abc4
ignore any modules with Liger
kashif Mar 19, 2025
8a0da7c
Only offload if activation is on CUDA
kashif Apr 1, 2025
5365106
Merge branch 'main' into activation-checkpoint
kashif Apr 1, 2025
02d936c
Merge branch 'main' into activation-checkpoint
kashif Apr 26, 2025
7abe1e5
fix for pytorch 2.4
kashif Apr 26, 2025
e3756c2
add docs
kashif Apr 26, 2025
d3e0726
Validate that activation_offloading and use_liger_kernel aren't both …
kashif Apr 26, 2025
74c3b55
fix CI
kashif Apr 26, 2025
eb06992
add tip
kashif Apr 26, 2025
d74cf84
Revert "add tip"
kashif Apr 26, 2025
9db9043
Merge branch 'main' into activation-checkpoint
kashif Apr 26, 2025
1935673
fix tests
kashif Apr 27, 2025
02ddbe5
Update docs/source/reducing_memory_usage.md
kashif Apr 28, 2025
9553718
Update docs/source/reducing_memory_usage.md
kashif Apr 28, 2025
d0d26e3
move sft test
qgallouedec Apr 28, 2025
96348d6
fix license
qgallouedec Apr 28, 2025
b317b0f
maybe_activation_offload_context and some doc
qgallouedec Apr 28, 2025
ff7dd54
better arg ordering and move param validation in SFTTrainer
qgallouedec Apr 28, 2025
0df3453
disable warning
qgallouedec Apr 28, 2025
82e3f91
better check for peft
kashif Apr 29, 2025
6cbd67d
Merge branch 'main' into activation-checkpoint
qgallouedec May 2, 2025
9ce14f8
Merge branch 'main' into activation-checkpoint
kashif May 7, 2025
c43f58f
move tests to SLOW as it needs accelerator
kashif May 7, 2025
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -99,6 +99,8 @@
title: Trainers
- local: models
title: Model Classes
- local: model_utils
title: Model Utilities
- local: best_of_n
title: Best of N Sampling
- local: judges
5 changes: 5 additions & 0 deletions docs/source/model_utils.md
@@ -0,0 +1,5 @@
# Model Utilities

## get_act_offloading_ctx_manager

[[autodoc]] models.get_act_offloading_ctx_manager
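
A minimal usage sketch of the manager returned by this utility, assuming it is built from the model and wraps the forward/backward pass; the model id and input shapes below are illustrative (they mirror this PR's tests), not a prescribed setup:

```python
import torch
from transformers import AutoModelForCausalLM

from trl.models import get_act_offloading_ctx_manager

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5").cuda()
input_ids = torch.randint(0, 100, (2, 10), device="cuda")

# Build the activation-offloading context manager for this model, then wrap
# the forward and backward passes so saved activations are parked on CPU RAM.
offload_ctx = get_act_offloading_ctx_manager(model=model)
with offload_ctx:
    loss = model(input_ids, labels=input_ids).loss
    loss.backward()
```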
38 changes: 37 additions & 1 deletion docs/source/reducing_memory_usage.md
@@ -14,7 +14,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>

To reduce memory usage, its important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

<hfoptions id="dpo">
<hfoption id="DPO">
@@ -129,6 +129,42 @@ training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_imple
</hfoption>
</hfoptions>

## Activation offloading

Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.

To enable activation offloading in your SFT training configuration:

<hfoptions id="activation_offloading">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., activation_offloading=True)
```

</hfoption>
</hfoptions>

<Tip warning={true}>

When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy: with `use_liger_kernel=True`, Liger cross entropy performs in-place operations that conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:

```python
# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig

training_args = SFTConfig(
activation_offloading=True,
use_liger_kernel=False, # Disable Liger cross entropy
# Other parameters...
)
```
</Tip>

Under the hood, activation offloading uses PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It decides which tensors to offload based on size and context, and avoids offloading output tensors, which would be inefficient. For performance, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
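
As a rough illustration only (not TRL's actual implementation), the sketch below shows the core idea with `torch.autograd.graph.saved_tensors_hooks`: large CUDA activations are moved to CPU when autograd saves them and moved back when the backward pass needs them. The helper name and size threshold are illustrative, and the real implementation is more careful (pinned memory, CUDA streams, and skipping output tensors).

```python
import torch


def offload_to_cpu_hooks(min_offload_size: int = 1024):
    """Return a saved_tensors_hooks context that parks large CUDA activations on CPU."""

    def pack(tensor: torch.Tensor):
        # Offload only sizeable CUDA tensors; tiny ones aren't worth the transfer.
        if tensor.is_cuda and tensor.numel() * tensor.element_size() >= min_offload_size:
            return tensor.device, tensor.to("cpu")
        return tensor.device, tensor

    def unpack(packed):
        device, tensor = packed
        # Bring the activation back to its original device for the backward pass.
        return tensor.to(device)

    return torch.autograd.graph.saved_tensors_hooks(pack, unpack)


# Wrap the forward pass; the hooks are recorded per saved tensor, so the
# backward pass can run outside the context:
#
#     with offload_to_cpu_hooks():
#         loss = model(input_ids, labels=input_ids).loss
#     loss.backward()
```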

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
156 changes: 156 additions & 0 deletions tests/test_activation_offloading.py
@@ -0,0 +1,156 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_peft, require_torch_accelerator
from transformers.utils import is_peft_available

from trl.models.activation_offloading import NoOpManager, OffloadActivations


if is_peft_available():
from peft import LoraConfig, get_peft_model


class TestActivationOffloading(unittest.TestCase):
@require_torch_accelerator
@require_peft
def test_offloading_with_peft_models(self) -> None:
"""Test that activation offloading works with PEFT models."""
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)
inp = torch.randint(0, 100, (2, 10), device="cuda")

# First forward-backward pass without offloading
torch.manual_seed(42)
loss = model(inp, labels=inp).loss
loss.backward()

# Store gradients - only from trainable parameters
grads_original = []
for name, param in model.named_parameters():
if param.requires_grad and param.grad is not None:
grads_original.append((name, param.grad.clone()))

# Reset gradients
for p in model.parameters():
if p.grad is not None:
p.grad = None

# Second forward-backward pass with offloading
torch.manual_seed(42)
with OffloadActivations():
loss_c = model(inp, labels=inp).loss
loss_c.backward()

# Compare gradients - only trainable parameters
for name_orig, grad_orig in grads_original:
for name_param, param in model.named_parameters():
if name_param == name_orig and param.requires_grad and param.grad is not None:
self.assertTrue(
torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
f"Gradient mismatch for {name_orig}",
)

@require_torch_accelerator
def test_noop_manager_with_offloading(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
inp = torch.randint(0, 100, (2, 10), device="cuda")

# Run with offloading but disable for specific section
with OffloadActivations():
# First forward-backward with normal offloading
torch.manual_seed(42)
out1 = model(inp, labels=inp)
out1.loss.backward()
grads1 = [p.grad.clone() for p in model.parameters()]

# Reset grads
for p in model.parameters():
p.grad = None

# Second forward-backward with NoOpManager
with NoOpManager():
torch.manual_seed(42)
out2 = model(inp, labels=inp)
out2.loss.backward()

grads2 = [p.grad.clone() for p in model.parameters()]

# Gradients should match as NoOpManager should have prevented offloading
for g1, g2 in zip(grads1, grads2):
self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))

@require_torch_accelerator
def test_min_offload_size(self):
"""Test that tensors smaller than min_offload_size aren't offloaded"""
model = nn.Sequential(
nn.Linear(5, 5), # Small layer that shouldn't be offloaded
nn.Linear(5, 1000), # Large layer that should be offloaded
).cuda()

inp = torch.randn(2, 5, device="cuda")

with OffloadActivations(min_offload_size=1000):
out = model(inp)
out.sum().backward()

# The test passes if no errors occur, as we're mainly testing
# that the logic handles both offloaded and non-offloaded tensors

@require_torch_accelerator
def test_real_hf_model(self):
"""Test with an actual HuggingFace model"""
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

# Create small input
inp = torch.randint(0, 100, (2, 10), device="cuda")

# Baseline without offloading
torch.manual_seed(42)
out1 = model(inp, labels=inp).loss
out1.backward()
grads1 = [p.grad.clone() for p in model.parameters()]

# Reset grads
for p in model.parameters():
p.grad = None

# With offloading
with OffloadActivations():
torch.manual_seed(42)
out2 = model(inp, labels=inp).loss
out2.backward()

grads2 = [p.grad.clone() for p in model.parameters()]

# Check outputs and gradients match
self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
for g1, g2 in zip(grads1, grads2):
self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
33 changes: 32 additions & 1 deletion tests/test_sft_trainer.py
@@ -27,7 +27,7 @@
TrainingArguments,
is_vision_available,
)
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.testing_utils import require_flash_attn, require_peft, require_torch_accelerator, require_vision
from transformers.utils import is_peft_available

from trl import SFTConfig, SFTTrainer
@@ -1229,3 +1229,34 @@ def test_train_padding_free(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@require_torch_accelerator
def test_train_offloading(self):
"""Test that activation offloading works with SFTTrainer."""
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir,
activation_offloading=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
2 changes: 2 additions & 0 deletions trl/models/__init__.py
@@ -18,6 +18,7 @@


_import_structure = {
"activation_offloading": ["get_act_offloading_ctx_manager"],
"modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": [
@@ -43,6 +44,7 @@
]

if TYPE_CHECKING:
from .activation_offloading import get_act_offloading_ctx_manager
from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import (