Commit 735e5d1

⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint (huggingface#3148)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 85e24bc · commit 735e5d1
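
This commit makes `GRPOTrainer._prepare_inputs` regenerate completions when the reuse buffer is empty: `_buffered_inputs` normally carries generations across gradient-accumulation steps, but after resuming from a checkpoint mid-cycle its slots are `None`, and the trainer previously handed those `None` batches onward instead of generating fresh ones.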

1 file changed (+6, −3)

trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -663,11 +663,14 @@ def _move_model_to_vllm(self):
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         mode = "eval" if self.control.should_evaluate else "train"
         if mode == "train":
-            if self.state.global_step % self.num_iterations == 0:
+            buffer_index = self._step % self.args.gradient_accumulation_steps
+            buffered_inputs = self._buffered_inputs[buffer_index]
+            if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
+                # buffered_inputs=None can occur when resuming from a checkpoint
                 inputs = self._generate_and_score_completions(inputs)
-                self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
+                self._buffered_inputs[buffer_index] = inputs
             else:
-                inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+                inputs = buffered_inputs
             self._step += 1
         else:
             # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
```
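
For context, here is a minimal, self-contained sketch of the buffering pattern the guard protects. All names here (`BufferedStepper`, `prepare`, `fresh_inputs`) are hypothetical stand-ins, not TRL API; the assumption, taken from the diff's own comment, is that the buffer slots start out as `None` after a checkpoint resume.

```python
from typing import Any, Optional


class BufferedStepper:
    """Toy stand-in (hypothetical, not TRL API) for GRPOTrainer's
    input buffering: completions generated on one step are reused
    across gradient-accumulation steps via a small ring buffer."""

    def __init__(self, num_iterations: int, gradient_accumulation_steps: int):
        self.num_iterations = num_iterations
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.global_step = 0
        self._step = 0
        # Assumption from the diff's comment: after resuming from a
        # checkpoint, the buffer holds None rather than cached batches.
        self._buffered_inputs: list[Optional[Any]] = [None] * gradient_accumulation_steps

    def prepare(self, fresh_inputs: Any) -> Any:
        buffer_index = self._step % self.gradient_accumulation_steps
        buffered = self._buffered_inputs[buffer_index]
        # The fix: regenerate not only on the usual num_iterations
        # cadence, but also whenever the slot is empty (None).
        if self.global_step % self.num_iterations == 0 or buffered is None:
            self._buffered_inputs[buffer_index] = fresh_inputs
            buffered = fresh_inputs
        self._step += 1
        return buffered


stepper = BufferedStepper(num_iterations=2, gradient_accumulation_steps=2)
stepper.global_step = 1  # mid-cycle, as if training just resumed
# Before the fix, the equivalent path returned the empty (None) slot;
# with the `or buffered is None` guard, fresh inputs are used instead.
assert stepper.prepare("batch-0") == "batch-0"
```

The commit's change is exactly this guard plus hoisting the buffer index into a local, so the empty-slot case falls through to `_generate_and_score_completions` instead of returning `None`.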
