
⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint #3148


Merged
merged 3 commits into huggingface:main on Mar 31, 2025

Conversation

PenutChen
Contributor

What does this PR do?

This PR refactors the _prepare_inputs method to ensure it never returns None when resuming training from checkpoints in GRPOTrainer.

Previously, when resuming, if self._buffered_inputs[...] was None and self.state.global_step % self.num_iterations != 0, the method returned None. This broke downstream code that expects non-None inputs.

To fix this, the logic has been updated so that if the buffered input is None, it always falls back to generating new inputs via _generate_and_score_completions(inputs), regardless of the step count. This ensures stability and correctness when resuming from checkpoints.

In addition, the modulo expression self._step % self.args.gradient_accumulation_steps has been refactored into a dedicated variable accumulation_index to improve readability and reduce redundancy.
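
For reference, here is a minimal sketch of the updated control flow. The attribute and method names (self._buffered_inputs, self.num_iterations, self._step, self.args.gradient_accumulation_steps, _generate_and_score_completions) are taken from the description above; the actual method in grpo_trainer.py contains additional logic and may differ in detail.

def _prepare_inputs(self, inputs):
    # Sketch only: simplified from the real GRPOTrainer method.
    accumulation_index = self._step % self.args.gradient_accumulation_steps
    buffered = self._buffered_inputs[accumulation_index]
    # Regenerate completions when a new iteration starts, or when the buffer is
    # empty (e.g. right after resuming from a checkpoint), so None is never returned.
    if self.state.global_step % self.num_iterations == 0 or buffered is None:
        inputs = self._generate_and_score_completions(inputs)
        self._buffered_inputs[accumulation_index] = inputs
    else:
        inputs = buffered
    self._step += 1
    return inputs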


Fixes

No corresponding issue was filed, but this change addresses a failure (a TypeError raised on None inputs) when using resume_from_checkpoint with GRPOTrainer.


Motivation and context

Users resuming training from checkpoints may encounter None inputs in _prepare_inputs, leading to errors in the training loop. This fix makes the method robust by guaranteeing it never returns None, and the refactoring improves the maintainability of the code.


Who can review?

Anyone familiar with GRPOTrainer, buffered input logic, or checkpoint resume behavior in TRL.

@qgallouedec
Member

Thanks!! Can you share a simple piece of code that would fail without your fix?

@PenutChen
Contributor Author

@qgallouedec here's the sample code:

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

def reward_fn(completions, **_):
    return [1.0 for _ in completions]

# Normal training
trainer = GRPOTrainer(
    model="facebook/opt-125m",
    args=GRPOConfig(
        output_dir="save/test",
        num_generations=2,
        per_device_train_batch_size=2,
        num_iterations=4,
        save_steps=1,
        max_steps=10,
        max_prompt_length=1,
        max_completion_length=1,
    ),
    reward_funcs=reward_fn,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()

# Simulating a fresh new trainer instance after interruption
trainer = GRPOTrainer(
    model="facebook/opt-125m",
    args=GRPOConfig(
        output_dir="save/test",
        num_generations=2,
        per_device_train_batch_size=2,
        num_iterations=4,
        save_steps=1,
        max_steps=10,
        max_prompt_length=1,
        max_completion_length=1,
    ),
    reward_funcs=reward_fn,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)

# Resume from a checkpoint saved at a step that is not divisible by num_iterations
# (6 % 4 != 0), so the old code tries to reuse a buffer the fresh process never filled
trainer.train(resume_from_checkpoint="save/test/checkpoint-6")

Error traceback:

Traceback (most recent call last):
  File "your_script.py", line 41, in <module>
    trainer.train(resume_from_checkpoint="save/test/checkpoint-6")
  File ".../transformers/trainer.py", line 2245, in train
    return inner_training_loop(
  File ".../transformers/trainer.py", line 2556, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File ".../transformers/trainer.py", line 3718, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File ".../trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
  File ".../trl/trainer/grpo_trainer.py", line 905, in compute_loss
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
                              ~~~~~~^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not subscriptable

@qgallouedec qgallouedec self-assigned this Mar 25, 2025
Member

@qgallouedec qgallouedec left a comment


Nice, thank you @PenutChen! In this case, the results won't be exactly the same as if we hadn't interrupted the training (we would have to save and load this buffer), but that's not a big deal.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title fix: handle None inputs when resuming GRPO Trainer from checkpoint ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint Mar 31, 2025
@qgallouedec qgallouedec merged commit 488025c into huggingface:main Mar 31, 2025
6 of 9 checks passed
kashif pushed a commit to kashif/trl that referenced this pull request Mar 31, 2025
BjarniHaukur pushed a commit to ASSERT-KTH/trl that referenced this pull request Apr 15, 2025
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025