[Models] Activation checkpointing from TorchTune #2954
Merged
Changes from 3 commits
Commits (45):
a1a2582  initial activaton offloading from torchtune (kashif)
1e5b989  formatting (kashif)
12afc5a  add support to SFT Trainer (kashif)
c25facb  License and minor style changes (qgallouedec)
7632e5e  Doc (qgallouedec)
408284c  fix path in docs (qgallouedec)
0d156db  use activation_offloading (kashif)
0aa23ad  only doc get_act_offloading_ctx_manager (qgallouedec)
bfb1333  Merge branch 'activation-checkpoint' of https://github.com/huggingfac… (qgallouedec)
acc921c  ✋ Prevent applying the chat template to tokenized datasets (#2939) (DanFosing)
f19bfa0  📇 GRPO: print completions to console and update docs (#2951) (nopepper)
31963ec  ↩️ Fix typo in TextEnvironment init param, should be max_tool_respons… (shenxiangzhuang)
f66feb2  🗿 Updated DPO default values for alpha and tau (#2918) (Ishan-Kumar2)
5b5caf1  📌 Pin liger-kernel and vLLM (#2952) (qgallouedec)
e7dcbfa  ⏪ Parameterize `enable_prefix_caching` (#2900) (ji-huazhong)
7eb5e68  🔢 Fix GRPO doc about `num_iterations` (#2966) (qgallouedec)
a483c89  Update grpo_trainer.py (#2973) (tpoisonooo)
bf830c5  Merge branch 'main' into activation-checkpoint (kashif)
95a52d9  Merge branch 'main' into activation-checkpoint (kashif)
eb4d315  only for training (kashif)
04ea97f  better docs (kashif)
6443645  Merge branch 'main' into activation-checkpoint (kashif)
141abc4  ignore any modules with Liger (kashif)
8a0da7c  Only offload if activation is on CUDA (kashif)
5365106  Merge branch 'main' into activation-checkpoint (kashif)
02d936c  Merge branch 'main' into activation-checkpoint (kashif)
7abe1e5  fix for pytorch 2.4 (kashif)
e3756c2  add docs (kashif)
d3e0726  Validate that activation_offloading and use_liger_kernel aren't both … (kashif)
74c3b55  fix CI (kashif)
eb06992  add tip (kashif)
d74cf84  Revert "add tip" (kashif)
9db9043  Merge branch 'main' into activation-checkpoint (kashif)
1935673  fix tests (kashif)
02ddbe5  Update docs/source/reducing_memory_usage.md (kashif)
9553718  Update docs/source/reducing_memory_usage.md (kashif)
d0d26e3  move sft test (qgallouedec)
96348d6  fix license (qgallouedec)
b317b0f  maybe_activation_offload_context and some doc (qgallouedec)
ff7dd54  better arg ordering and move param validation in SFTTrainer (qgallouedec)
0df3453  disable warning (qgallouedec)
82e3f91  better check for peft (kashif)
6cbd67d  Merge branch 'main' into activation-checkpoint (qgallouedec)
9ce14f8  Merge branch 'main' into activation-checkpoint (kashif)
c43f58f  move tests to SLOW as it needs accelerator (kashif)
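
Taken together, the commits wire TorchTune's activation offloading into the SFT trainer behind a single config flag, reject combining it with use_liger_kernel (d3e0726), and only offload activations that already live on CUDA (8a0da7c). Below is a minimal usage sketch, not the PR's documentation: the flag name follows the 3-commit snapshot shown further down (enable_activation_offloading), while commit 0d156db suggests it was later renamed to activation_offloading, and the output directory and dataset slice are placeholders.

# Minimal sketch; flag name and merged API may differ (see commit 0d156db).
from datasets import load_dataset
from transformers import AutoModelForCausalLM

from trl.trainer.sft_trainer import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
train_dataset = load_dataset("stanfordnlp/imdb", split="train[:1%]")  # placeholder dataset slice

training_args = SFTConfig(
    output_dir="/tmp/sft-activation-offloading",  # placeholder output dir
    per_device_train_batch_size=2,
    max_steps=1,
    enable_activation_offloading=True,  # offload saved activations to CPU during training
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()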
New test file (184 added lines), as shown in the 3-commit view:

import gc
import tempfile
import unittest

import torch
from datasets import load_dataset
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_peft, require_torch_accelerator
from transformers.utils import is_peft_available

from trl.models.activation_offloading import NoOpManager, OffloadActivations
from trl.trainer.sft_trainer import SFTConfig, SFTTrainer


if is_peft_available():
    from peft import LoraConfig, get_peft_model


class TestActivationOffloading(unittest.TestCase):
    def setUp(self):
        self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
        self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
        self.max_length = 128
        self.peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=8,
            bias="none",
            task_type="CAUSAL_LM",
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    @require_torch_accelerator
    def test_offloading_with_sft_trainer(self) -> None:
        """Test that activation offloading works with SFTTrainer."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=1,
                enable_activation_offloading=True,
                report_to="none",
            )

            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
            )

            # Train for one step
            trainer.train()

            # Verify training completed successfully
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    @require_torch_accelerator
    @require_peft
    def test_offloading_with_peft_models(self) -> None:
        """Test that activation offloading works with PEFT models."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        model = get_peft_model(model, self.peft_config)
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # First forward-backward pass without offloading
        torch.manual_seed(42)
        loss = model(inp, labels=inp).loss
        loss.backward()

        # Store gradients - only from trainable parameters
        grads_original = []
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grads_original.append((name, param.grad.clone()))

        # Reset gradients
        for p in model.parameters():
            if p.grad is not None:
                p.grad = None

        # Second forward-backward pass with offloading
        torch.manual_seed(42)
        with OffloadActivations(use_streams=True):
            loss_c = model(inp, labels=inp).loss
            loss_c.backward()

        # Compare gradients - only trainable parameters
        for name_orig, grad_orig in grads_original:
            for name_param, param in model.named_parameters():
                if name_param == name_orig and param.requires_grad and param.grad is not None:
                    self.assertTrue(
                        torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
                        f"Gradient mismatch for {name_orig}",
                    )

    @require_torch_accelerator
    def test_noop_manager_with_offloading(self):
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # Run with offloading but disable it for a specific section
        with OffloadActivations(use_streams=True):
            # First forward-backward with normal offloading
            torch.manual_seed(42)
            out1 = model(inp, labels=inp)
            out1.loss.backward()
            grads1 = [p.grad.clone() for p in model.parameters()]

            # Reset grads
            for p in model.parameters():
                p.grad = None

            # Second forward-backward with NoOpManager
            with NoOpManager():
                torch.manual_seed(42)
                out2 = model(inp, labels=inp)
                out2.loss.backward()

            grads2 = [p.grad.clone() for p in model.parameters()]

        # Gradients should match, as NoOpManager should have prevented offloading
        for g1, g2 in zip(grads1, grads2):
            self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))

    @require_torch_accelerator
    def test_min_offload_size(self):
        """Test that tensors smaller than min_offload_size aren't offloaded."""
        model = nn.Sequential(
            nn.Linear(5, 5),  # Small layer that shouldn't be offloaded
            nn.Linear(5, 1000),  # Large layer that should be offloaded
        ).cuda()

        inp = torch.randn(2, 5, device="cuda")

        with OffloadActivations(min_offload_size=1000):
            out = model(inp)
            out.sum().backward()

        # The test passes if no errors occur, as we're mainly testing
        # that the logic handles both offloaded and non-offloaded tensors

    @require_torch_accelerator
    def test_real_hf_model(self):
        """Test with an actual HuggingFace model."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()

        # Create small input
        inp = torch.randint(0, 100, (2, 10), device="cuda")

        # Baseline without offloading
        torch.manual_seed(42)
        out1 = model(inp, labels=inp).loss
        out1.backward()
        grads1 = [p.grad.clone() for p in model.parameters()]

        # Reset grads
        for p in model.parameters():
            p.grad = None

        # With offloading
        with OffloadActivations(use_streams=True):
            torch.manual_seed(42)
            out2 = model(inp, labels=inp).loss
            out2.backward()

        grads2 = [p.grad.clone() for p in model.parameters()]

        # Check outputs and gradients match
        self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
        for g1, g2 in zip(grads1, grads2):
            self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
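
Beyond the trainer flag, the tests above use OffloadActivations and NoOpManager directly from trl.models.activation_offloading. A condensed sketch of that pattern, assuming only the context-manager behavior the tests rely on (the checkpoint id is the internal tiny test model used above):

# Sketch condensed from the tests above; not an authoritative API reference.
import torch
from transformers import AutoModelForCausalLM

from trl.models.activation_offloading import NoOpManager, OffloadActivations

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
).cuda()
inp = torch.randint(0, 100, (2, 10), device="cuda")

# Saved activations are moved to CPU during forward and brought back for backward;
# use_streams overlaps the transfers with compute.
with OffloadActivations(use_streams=True):
    loss = model(inp, labels=inp).loss
    loss.backward()

    # A nested NoOpManager region keeps its activations on the GPU,
    # i.e. offloading is temporarily disabled inside it.
    with NoOpManager():
        loss = model(inp, labels=inp).loss
        loss.backward()

Per test_min_offload_size, OffloadActivations(min_offload_size=...) skips tensors below the given size, so tiny activations stay on the GPU. The commit history (141abc4, d3e0726) indicates the Liger-kernel interaction is handled separately: Liger modules are ignored and combining activation offloading with use_liger_kernel is rejected at validation time.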