[Models] Activation checkpointing from TorchTune #2954


Merged: 45 commits into main, May 7, 2025

Conversation

@kashif (Collaborator) commented Feb 25, 2025

What does this PR do?

Adapt Torchtune's activation checkpointing for HF models
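For orientation, here is a minimal sketch of the pattern this enables, simplified and not taken from the PR itself; the torchtune import path is an assumption, and the PR adapts the implementation inside TRL for HF models (torchtune's unadapted helper may not recognize an HF model's output head on its own):

from contextlib import nullcontext

from torchtune.training import get_act_offloading_ctx_manager  # assumed import path
from transformers import AutoModelForCausalLM, AutoTokenizer

# Requires a CUDA device; activation offloading moves saved activations to CPU
# during the forward pass and restores them on demand during backward.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B").cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

enable_activation_offloading = True
if enable_activation_offloading:
    offload_ctx = get_act_offloading_ctx_manager(
        model=model,
        enable_activation_offloading=True,
    )
else:
    offload_ctx = nullcontext()

batch = tokenizer("Hello world", return_tensors="pt").to(model.device)
with offload_ctx:
    loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()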

@kashif requested a review from lewtun, February 25, 2025 10:09

@qgallouedec (Member) commented:

Nice! When it works, can you also add a few lines in https://huggingface.co/docs/trl/en/reducing_memory_usage? 🙏
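For reference, a minimal sketch of the kind of snippet that docs page could gain, assuming the config flag ends up being called activation_offloading (the name used in the benchmark script later in this thread):

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2-0.5B-SFT",
    activation_offloading=True,  # assumed flag name; offloads saved activations to CPU
)
trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()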

@lewtun (Member) left a comment

Thanks for adding this @kashif ! Overall it looks great :) In addition to Quentin's comment about the docs, would you mind running a benchmark with e.g. SFT and GRPO so we can have a rough idea of both the memory saved and the impact on throughput?

@kashif (Collaborator, Author) commented Feb 25, 2025

sure!

Comment on lines 245 to 248
self.activation_offload_context = get_act_offloading_ctx_manager(
    model=self.model,
    enable_activation_offloading=self.args.enable_activation_offloading,
)
@qgallouedec (Member) commented Feb 25, 2025

I'd find it more readable like this:

Suggested change
-self.activation_offload_context = get_act_offloading_ctx_manager(
-    model=self.model,
-    enable_activation_offloading=self.args.enable_activation_offloading,
-)
+if self.args.enable_activation_offloading:
+    self.activation_offload_context = get_act_offloading_ctx_manager(
+        model=self.model,
+        enable_activation_offloading=self.args.enable_activation_offloading,
+    )

and later in the code:

context = self.activation_offload_context if self.args.enable_activation_offloading else nullcontext()

@kashif (Collaborator, Author) replied:

get_act_offloading_ctx_manager already returns a nullcontext when enable_activation_offloading=False, which is why I don't wrap the call in an if statement. Or would you prefer we remove that logic from get_act_offloading_ctx_manager?
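For readers following along, a minimal sketch of the early-exit behavior being described (not the actual torchtune/TRL source):

from contextlib import nullcontext

def get_act_offloading_ctx_manager(model, enable_activation_offloading):
    # Always return *some* context manager, so the caller can write
    # `with self.activation_offload_context:` unconditionally.
    if not enable_activation_offloading:
        return nullcontext()
    # construction of the real offloading context is elided in this sketch
    raise NotImplementedError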

@qgallouedec (Member) replied:

Yes, I find it more explicit to disable outside the function than inside. But it's really not very important.

qgallouedec and others added 14 commits February 25, 2025 15:00
* Update sft_config.py

* Update sft_trainer.py

* Update sft_config.py

* Update sft_trainer.py

* Apply style fixes

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* ✨ Enhance GRPO logging with configurable completions sampling

- Update `GRPOConfig` to replace `log_completions` with `log_completions_steps`
- Add `print_prompt_completions_sample()` utility function for rich console logging
- Modify `GRPOTrainer` to additionally print 5 random prompt-completion pairs every log_completions_steps steps

* GRPO trainer completions logging, move wandb checks together

* Add rich availability check and use fallback in print_prompt_completions_sample when rich is not available

* Update docstrings on print_prompt_completions_sample

Co-authored-by: Quentin Gallouédec <[email protected]>

* Revert back to simple log_completions bool

* GRPO log completions fully

* Remove print fallback from print_prompt_completions_sample

* Move accelerator main process check up for grpo log completions

* Explicit variable names in print_prompt_completions_sample

* Make GRPOConfig docstring match field description

* Update log_completions docs again

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update GRPOConfig docs to match field

* improve readability when prompts or completions are multiline

* log reward

* prevent hanging, don't print without rich, print reward

* style

---------

Co-authored-by: Robert Veres <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
* updated DPO default values for alpha and tau

* same for grpo

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
* pin liger-kernel

* style
* parameterize enable_prefix_caching

* apply review suggestion

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
@qgallouedec marked this pull request as draft, February 28, 2025 13:59
@casper-hansen commented:

Looks like this PR was parked for now. @kashif, did the implementation not work? This is super relevant to me if I am going to use TRL for training long-context reasoners.

@kashif (Collaborator, Author) commented Mar 3, 2025

@casper-hansen So during training I see a reduction in memory, but memory jumps back up to the default level as soon as the eval steps start; I'm investigating why...

@casper-hansen replied:

@kashif The implementation relies on torch.autograd.graph.saved_tensors_hooks which would seem to only work if you are computing gradients.
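For context, a minimal sketch of that mechanism, adapted from the PyTorch saved_tensors_hooks documentation (torchtune's implementation adds pinned memory and CUDA streams on top, none of which is shown here). The hooks only fire when autograd saves tensors for backward, which is consistent with memory jumping back up during evaluation:

import torch

def pack_to_cpu(tensor):
    # Called whenever autograd saves an activation for the backward pass.
    return tensor.device, tensor.cpu()

def unpack_from_cpu(packed):
    # Called during backward when the saved activation is needed again.
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(8, 1024, device=device, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()  # saved activations are packed (offloaded) here
loss.backward()            # and unpacked (restored) here

# Under torch.no_grad(), e.g. during evaluation, autograd saves nothing,
# so the hooks never run and no memory is offloaded.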

@kashif (Collaborator, Author) commented Apr 26, 2025

@casper-hansen I have added the documentation as well as a check that disables activation offloading when use_liger_kernel is True.

@casper-hansen, quoting the reply above:

In my personal opinion, it's suboptimal to disable it in all cases. The only case where it's incompatible is when using the fused linear cross entropy. Maybe this can be fixed in a follow-up PR.
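A hypothetical, narrower guard along those lines; the class-name check is an assumption about how the Liger fused linear cross-entropy shows up under named_modules, not the PR's actual logic:

def uses_liger_fused_linear_ce(model) -> bool:
    # NOTE: hypothetical detection; only flags the fused linear cross-entropy piece
    # so that activation offloading could stay enabled for the rest of the model.
    for _, module in model.named_modules():
        if "fusedlinearcrossentropy" in type(module).__name__.lower():
            return True
    return False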

@kashif (Collaborator, Author) commented Apr 26, 2025

Here is the plot, @casper-hansen:

[Plot: memory_vs_sequence_length.png, peak GPU memory usage (GB) vs. sequence length, with and without activation offloading]

using:

#!/usr/bin/env python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to benchmark GPU memory usage with and without activation offloading.

Example:
python trl/scripts/activation_offloading_benchmark.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --start_length 128 \
    --max_length 4096 \
    --step_size 128 \
    --num_train_steps 5 \
    --per_device_train_batch_size 1
"""

import argparse
import gc
import json
import os
import sys
import traceback
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import SFTConfig, SFTTrainer

def measure_memory():
    """Measure current and peak GPU memory in GB"""
    current = torch.cuda.memory_allocated() / (1024 ** 3)
    peak = torch.cuda.max_memory_allocated() / (1024 ** 3)
    return current, peak

def reset_memory_stats():
    """Reset the peak memory stats"""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    gc.collect()

def train_with_config(config, model_name, dataset, sequence_length, num_steps):
    """Run training with a specific configuration and measure memory usage"""
    try:
        # Create model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Set up trainer
        trainer = SFTTrainer(
            model=model,
            args=config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        # Reset memory stats before training
        reset_memory_stats()
        
        # Capture starting memory
        _, start_peak = measure_memory()
        
        # Train for a few steps - no parameters needed as max_steps is in config
        trainer.train()
        
        # Measure memory after training
        _, peak = measure_memory()
        
        # Clean up
        del trainer
        del model
        del tokenizer
        reset_memory_stats()
        
        # Return peak memory during training
        return peak, True  # Success
    
    except Exception as e:
        is_oom = "CUDA out of memory" in str(e)
        error_msg = "OOM" if is_oom else str(e)
        print(f"Error with sequence length {sequence_length}: {error_msg}")
        if not is_oom:
            traceback.print_exc()
        
        # Clean up - use variable names directly to avoid AttributeError
        try:
            if 'trainer' in locals():
                del trainer
            if 'model' in locals():
                del model
            if 'tokenizer' in locals():
                del tokenizer
        except:
            pass
            
        reset_memory_stats()
        
        return None, False  # Failure

def run_benchmark(args):
    """Run the complete benchmark with increasing sequence lengths"""
    # Load dataset
    dataset = load_dataset(args.dataset_name, split=f"train[:{args.dataset_size}]")
    
    # Results storage
    results = {
        "with_offloading": {"seq_lengths": [], "memory": []},
        "without_offloading": {"seq_lengths": [], "memory": []}
    }
    
    # Directory for output
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    
    # Test both with and without activation offloading
    for use_offloading in [True, False]:
        mode = "with_offloading" if use_offloading else "without_offloading"
        print(f"\n{'='*50}\nTesting {mode}\n{'='*50}")
        
        seq_length = args.start_length
        
        while seq_length <= args.max_length:
            print(f"\nTesting sequence length: {seq_length}")
            
            # Create config with current sequence length
            config = SFTConfig(
                output_dir=str(output_dir / f"temp-{mode}-{seq_length}"),
                max_length=seq_length,
                packing=True,
                per_device_train_batch_size=args.per_device_train_batch_size,
                gradient_accumulation_steps=1,
                learning_rate=args.learning_rate,
                logging_steps=1,
                max_steps=args.num_train_steps,
                activation_offloading=use_offloading,
                remove_unused_columns=False,
                report_to="none",
            )
            
            # Train with this config
            peak_memory, success = train_with_config(
                config, args.model_name_or_path, dataset, seq_length, args.num_train_steps
            )
            
            if peak_memory is not None:
                results[mode]["seq_lengths"].append(seq_length)
                results[mode]["memory"].append(peak_memory)
                print(f"Sequence length {seq_length}: Peak memory {peak_memory:.2f} GB")
            
            if not success:
                print(f"Failed at sequence length {seq_length}, stopping {mode} tests")
                break
            
            # Increase sequence length for next iteration
            seq_length += args.step_size
    
    return results

def plot_results(results, output_path):
    """Plot memory usage for both configurations"""
    plt.figure(figsize=(10, 6))
    
    # Plot both configurations
    for mode, data in results.items():
        if data["seq_lengths"]:  # Only plot if we have data
            label = "With Activation Offloading" if mode == "with_offloading" else "Without Activation Offloading"
            plt.plot(data["seq_lengths"], data["memory"], 'o-', label=label)
    
    plt.xlabel('Sequence Length')
    plt.ylabel('Peak GPU Memory Usage (GB)')
    plt.title('GPU Memory Usage vs Sequence Length')
    plt.grid(True)
    plt.legend()
    
    # Save the plot
    plt.savefig(output_path)
    print(f"Plot saved to {output_path}")

def main():
    parser = argparse.ArgumentParser(description="Benchmark activation offloading memory usage")
    
    # Model and dataset arguments
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Path to pretrained model")
    parser.add_argument("--dataset_name", type=str, default="trl-lib/Capybara", help="Dataset to use for training")
    parser.add_argument("--dataset_size", type=int, default=100, help="Number of examples to use from dataset")
    
    # Sequence length parameters
    parser.add_argument("--start_length", type=int, default=128, help="Starting sequence length")
    parser.add_argument("--max_length", type=int, default=4096, help="Maximum sequence length to try")
    parser.add_argument("--step_size", type=int, default=128, help="Increment for sequence length per step")
    
    # Training parameters
    parser.add_argument("--num_train_steps", type=int, default=5, help="Number of training steps per sequence length")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1, help="Batch size per device")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="./activation_offloading_benchmark", 
                       help="Directory to save temporary files and output")
    
    args = parser.parse_args()
    
    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires a GPU.")
        return 1
    
    # Run the benchmark
    results = run_benchmark(args)
    
    # Plot the results
    plot_path = os.path.join(args.output_dir, "memory_vs_sequence_length.png")
    plot_results(results, plot_path)
    
    # Save the raw results
    with open(os.path.join(args.output_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)
    
    return 0

if __name__ == "__main__":
    sys.exit(main()) 

run with:

python trl/scripts/activation_offloading_benchmark.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --start_length 128 \
    --max_length 2048 \
    --step_size 128 \
    --num_train_steps 5

@lewtun (Member) left a comment

Very clean PR and integration with TRL @kashif !

Regarding the benchmark, could you run this with a 7B model so we can see when SFT goes OOM for a given sequence length without activation checkpointing? I guess it may also be simpler to use packing in your benchmark so you are guaranteed to have fixed chunks of size max_length.

Note that this isn't a blocking requirement to get this merged once @qgallouedec has approved - I think it would mainly be nice to include in the docs as a plot :)

Comment on lines 40 to 41
activation_offloading (`bool`, *optional*, defaults to `False`):
    Whether to offload the activations to the CPU.

Could you move this to the section > Parameters that control the training 🙏

Comment on lines +441 to +443
# Disable offloading for any Liger modules
for name, module in unwrapped_model.named_modules():
    if "liger" in name.lower():

Does this function support liger or not? If the user is prevented from using it with Liger, perhaps this part is useless?

@qgallouedec (Member) left a comment

Wonderful work

@lewtun changed the title from "[Models] Activation checkpointing from TrorchTune" to "[Models] Activation checkpointing from TorchTune" on May 7, 2025
@kashif merged commit cafa663 into main on May 7, 2025 (10 of 11 checks passed)
@kashif deleted the activation-checkpoint branch, May 7, 2025 10:36