[Models] Activation checkpointing from TorchTune #2954
Conversation
Nice! When it works, can you also add a few lines in https://huggingface.co/docs/trl/en/reducing_memory_usage? 🙏
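For reference, a minimal sketch of what such a docs entry could show, using the activation_offloading flag this PR adds to SFTConfig (the exact prose would live in the reducing-memory-usage page; model and dataset below reuse the benchmark defaults from later in this thread):

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Activation offloading moves activations to CPU during the forward pass and
# brings them back for the backward pass, trading throughput for GPU memory.
training_args = SFTConfig(
    output_dir="Qwen2-0.5B-SFT",
    activation_offloading=True,
)
trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()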
Thanks for adding this @kashif! Overall it looks great :) In addition to Quentin's comment about the docs, would you mind running a benchmark with e.g. SFT and GRPO so we can have a rough idea of both the memory saved and the impact on throughput?

sure!
trl/trainer/sft_trainer.py (outdated)

self.activation_offload_context = get_act_offloading_ctx_manager(
    model=self.model,
    enable_activation_offloading=self.args.enable_activation_offloading,
)
I'd find it more readable like this:

if self.args.enable_activation_offloading:
    self.activation_offload_context = get_act_offloading_ctx_manager(
        model=self.model,
        enable_activation_offloading=self.args.enable_activation_offloading,
    )
and later in the code:
context = self.activation_offload_context if self.args.enable_activation_offloading else nullcontext()
get_act_offloading_ctx_manager already returns a nullcontext when enable_activation_offloading=False, so that is why I do not need to put this in an if statement. Or should we remove that logic from the get_act_offloading_ctx_manager method?
Yes, I find it more explicit to disable outside the function than inside. But it's really not very important.
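For context, a minimal sketch of the two patterns being discussed; _offload_activations is an illustrative stand-in for TorchTune's real offloading manager, not its actual implementation:

from contextlib import contextmanager, nullcontext

@contextmanager
def _offload_activations(model):
    # Illustrative stand-in for the real TorchTune offloading context manager.
    yield

def get_act_offloading_ctx_manager(model, enable_activation_offloading):
    # The pattern kashif describes: when the flag is off, the function itself
    # returns a no-op context, so the call site needs no if-statement.
    if not enable_activation_offloading:
        return nullcontext()
    return _offload_activations(model)

# Quentin's alternative keeps the branch explicit at the call site instead:
# context = self.activation_offload_context if self.args.enable_activation_offloading else nullcontext()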
…e/trl into activation-checkpoint
* Update sft_config.py
* Update sft_trainer.py
* Update sft_config.py
* Update sft_trainer.py
* Apply style fixes

---------
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* ✨ Enhance GRPO logging with configurable completions sampling
  - Update GRPOConfig to replace log_completions with log_completions_steps
  - Add print_prompt_completions_sample() utility function for rich console logging
  - Modify GRPOTrainer to additionally print 5 random prompt-completion pairs every log_completions_steps steps
* GRPO trainer completions logging, move wandb checks together
* Add rich availability check and use fallback in print_prompt_completions_sample when rich is not available
* Update docstrings on print_prompt_completions_sample (Co-authored-by: Quentin Gallouédec <[email protected]>)
* Revert back to simple log_completions bool
* GRPO log completions fully
* Remove print fallback from print_prompt_completions_sample
* Move accelerator main process check up for grpo log completions
* Explicit variable names in print_prompt_completions_sample
* Make GRPOConfig docstring match field description
* Update log_completions docs again (Co-authored-by: Quentin Gallouédec <[email protected]>)
* Update GRPOConfig docs to match field
* improve readability when prompt or completions are multiline
* log reward
* prevent hanging, don't print without rich, print reward
* style

---------
Co-authored-by: Robert Veres <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
#2921) Co-authored-by: Quentin Gallouédec <[email protected]>
* updated DPO default values for alpha and tau
* same for grpo

---------
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

* pin liger-kernel
* style

* parameterize enable_prefix_caching
* apply review suggestion

---------
Co-authored-by: Quentin Gallouédec <[email protected]>
Looks like this PR was parked for now. @kashif did the implementation not work? This is super relevant to me if I am going to use TRL for training long-context reasoners.
@casper-hansen So during training I see a reduction in memory, but memory jumps back up to the default as soon as the eval step starts, and I am investigating why...
@kashif The implementation relies on
@casper-hansen I have added the documentation, as well as a check that disables the activation offloading when the
In my personal opinion, it's suboptimal to disable it in all cases. The only case where it's incompatible is when using the fused linear cross-entropy. Maybe this can be fixed in a follow-up PR.
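A rough sketch of the narrower guard being suggested: disable offloading only when the known-incompatible kernel is active. The flag name use_liger_fused_linear_cross_entropy is hypothetical; a real check would inspect the actual Liger configuration:

import warnings

def maybe_disable_activation_offloading(args):
    # Hypothetical follow-up: only the fused linear cross-entropy path is
    # incompatible, so other Liger kernels could keep offloading enabled.
    fused_ce = getattr(args, "use_liger_fused_linear_cross_entropy", False)  # assumed flag
    if getattr(args, "activation_offloading", False) and fused_ce:
        warnings.warn(
            "Activation offloading is incompatible with Liger's fused linear "
            "cross-entropy kernel; disabling activation offloading."
        )
        args.activation_offloading = False
    return args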
here is the plot @casper-hansen using:

#!/usr/bin/env python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to benchmark GPU memory usage with and without activation offloading.
Example:
python trl/scripts/activation_offloading_benchmark.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--start_length 128 \
--max_length 4096 \
--step_size 128 \
--num_train_steps 5 \
--per_device_train_batch_size 1
"""
import argparse
import os
import gc
import sys
import traceback
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, TrlParser
def measure_memory():
"""Measure current and peak GPU memory in GB"""
current = torch.cuda.memory_allocated() / (1024 ** 3)
peak = torch.cuda.max_memory_allocated() / (1024 ** 3)
return current, peak
def reset_memory_stats():
"""Reset the peak memory stats"""
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()
def train_with_config(config, model_name, dataset, sequence_length, num_steps):
"""Run training with a specific configuration and measure memory usage"""
try:
# Create model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set up trainer
trainer = SFTTrainer(
model=model,
args=config,
train_dataset=dataset,
processing_class=tokenizer,
)
# Reset memory stats before training
reset_memory_stats()
# Capture starting memory
_, start_peak = measure_memory()
# Train for a few steps - no parameters needed as max_steps is in config
trainer.train()
# Measure memory after training
_, peak = measure_memory()
# Clean up
del trainer
del model
del tokenizer
reset_memory_stats()
# Return peak memory during training
return peak, True # Success
except Exception as e:
is_oom = "CUDA out of memory" in str(e)
error_msg = "OOM" if is_oom else str(e)
print(f"Error with sequence length {sequence_length}: {error_msg}")
if not is_oom:
traceback.print_exc()
# Clean up - use variable names directly to avoid AttributeError
try:
if 'trainer' in locals():
del trainer
if 'model' in locals():
del model
if 'tokenizer' in locals():
del tokenizer
except:
pass
reset_memory_stats()
return None, False # Failure
def run_benchmark(args):
"""Run the complete benchmark with increasing sequence lengths"""
# Load dataset
dataset = load_dataset(args.dataset_name, split=f"train[:{args.dataset_size}]")
# Results storage
results = {
"with_offloading": {"seq_lengths": [], "memory": []},
"without_offloading": {"seq_lengths": [], "memory": []}
}
# Directory for output
output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
# Test both with and without activation offloading
for use_offloading in [True, False]:
mode = "with_offloading" if use_offloading else "without_offloading"
print(f"\n{'='*50}\nTesting {mode}\n{'='*50}")
seq_length = args.start_length
while seq_length <= args.max_length:
print(f"\nTesting sequence length: {seq_length}")
# Create config with current sequence length
config = SFTConfig(
output_dir=str(output_dir / f"temp-{mode}-{seq_length}"),
max_length=seq_length,
packing=True,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=1,
learning_rate=args.learning_rate,
logging_steps=1,
max_steps=args.num_train_steps,
activation_offloading=use_offloading,
remove_unused_columns=False,
report_to="none",
)
# Train with this config
peak_memory, success = train_with_config(
config, args.model_name_or_path, dataset, seq_length, args.num_train_steps
)
if peak_memory is not None:
results[mode]["seq_lengths"].append(seq_length)
results[mode]["memory"].append(peak_memory)
print(f"Sequence length {seq_length}: Peak memory {peak_memory:.2f} GB")
if not success:
print(f"Failed at sequence length {seq_length}, stopping {mode} tests")
break
# Increase sequence length for next iteration
seq_length += args.step_size
return results
def plot_results(results, output_path):
"""Plot memory usage for both configurations"""
plt.figure(figsize=(10, 6))
# Plot both configurations
for mode, data in results.items():
if data["seq_lengths"]: # Only plot if we have data
label = "With Activation Offloading" if mode == "with_offloading" else "Without Activation Offloading"
plt.plot(data["seq_lengths"], data["memory"], 'o-', label=label)
plt.xlabel('Sequence Length')
plt.ylabel('Peak GPU Memory Usage (GB)')
plt.title('GPU Memory Usage vs Sequence Length')
plt.grid(True)
plt.legend()
# Save the plot
plt.savefig(output_path)
print(f"Plot saved to {output_path}")
def main():
parser = argparse.ArgumentParser(description="Benchmark activation offloading memory usage")
# Model and dataset arguments
parser.add_argument("--model_name_or_path", type=str, required=True, help="Path to pretrained model")
parser.add_argument("--dataset_name", type=str, default="trl-lib/Capybara", help="Dataset to use for training")
parser.add_argument("--dataset_size", type=int, default=100, help="Number of examples to use from dataset")
# Sequence length parameters
parser.add_argument("--start_length", type=int, default=128, help="Starting sequence length")
parser.add_argument("--max_length", type=int, default=4096, help="Maximum sequence length to try")
parser.add_argument("--step_size", type=int, default=128, help="Increment for sequence length per step")
# Training parameters
parser.add_argument("--num_train_steps", type=int, default=5, help="Number of training steps per sequence length")
parser.add_argument("--per_device_train_batch_size", type=int, default=1, help="Batch size per device")
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
# Output parameters
parser.add_argument("--output_dir", type=str, default="./activation_offloading_benchmark",
help="Directory to save temporary files and output")
args = parser.parse_args()
if not torch.cuda.is_available():
print("CUDA is not available. This benchmark requires a GPU.")
return 1
# Run the benchmark
results = run_benchmark(args)
# Plot the results
plot_path = os.path.join(args.output_dir, "memory_vs_sequence_length.png")
plot_results(results, plot_path)
# Save the raw results
import json
with open(os.path.join(args.output_dir, "results.json"), "w") as f:
json.dump(results, f, indent=2)
return 0
if __name__ == "__main__":
    sys.exit(main())

run with:

python trl/scripts/activation_offloading_benchmark.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--start_length 128 \
--max_length 2048 \
--step_size 128 \
    --num_train_steps 5
Very clean PR and integration with TRL @kashif !
Regarding the benchmark, could you run this with a 7B model so we can see when SFT goes OOM for a given sequence length without activation checkpointing? I guess it may also be simpler to use packing in your benchmark so you are guaranteed to have fixed chunks of size max_length.
Note that this isn't a blocking requirement to get this merged once @qgallouedec has approved - I think it would mainly be nice to include in the docs as a plot :)
Co-authored-by: lewtun <[email protected]>
trl/trainer/sft_config.py (outdated)

activation_offloading (bool, *optional*, defaults to False):
    Whether to offload the activations to the CPU.
Could you move this to the section "Parameters that control the training"? 🙏
# Disable offloading for any Liger modules
for name, module in unwrapped_model.named_modules():
    if "liger" in name.lower():
Does this function support liger or not? If the user is prevented from using it with Liger, perhaps this part is useless?
Wonderful work
What does this PR do?
Adapts TorchTune's activation offloading for HF models.