Skip to content

🆙 Bump transformers to 4.51 and use _VALID_DICT_FIELDS #3553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits
Jun 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ jobs:
uv pip install ".[dev]"
uv pip install accelerate==1.4.0
uv pip install datasets==3.0.0
uv pip install transformers==4.50.0
uv pip install transformers==4.51.0

- name: Test with pytest
run: |
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
accelerate>=1.4.0
datasets>=3.0.0
transformers>=4.50.0
transformers>=4.51.0
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ include_package_data = True
install_requires =
accelerate>=1.4.0
datasets>=3.0.0
transformers>=4.50.0
transformers>=4.51.0

[options.packages.find]
exclude =
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/bco_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class BCOConfig(TrainingArguments):
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
logging_steps: float = field(
default=10,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/cpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class CPOConfig(TrainingArguments):
Number of processes to use for processing the dataset.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class DPOConfig(TrainingArguments):
evaluation.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/gkd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments

from .sft_config import SFTConfig


Expand Down Expand Up @@ -50,6 +52,8 @@ class GKDConfig(SFTConfig):
on teacher-generated output).
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]

temperature: float = field(
default=0.9,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
Expand Down
3 changes: 0 additions & 3 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,9 +844,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids_batch = input_ids_batch[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/iterative_sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
Whether to optimize CUDA cache for slightly more memory-efficient training.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
logging_steps: float = field(
default=10,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/kto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class KTOConfig(TrainingArguments):
the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
Expand Down
48 changes: 2 additions & 46 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import datasets
import jinja2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -41,7 +40,7 @@
is_apex_available,
is_wandb_available,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.trainer_utils import EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

Expand Down Expand Up @@ -694,9 +693,8 @@ def training_step(
return loss.detach() / self.args.gradient_accumulation_steps

# Same as Trainer._maybe_log_save_evaluate but log our metrics
# start_time defaults to None to allow compatibility with transformers<=4.46
def _maybe_log_save_evaluate(
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
logs: dict[str, float] = {}
Expand Down Expand Up @@ -737,48 +735,6 @@ def _maybe_log_save_evaluate(
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
# This can be removed once the minimum transformers version is updated to 4.47.
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
def _determine_best_metric(self, metrics, trial):
"""
Determine if the model should be saved based on the evaluation metrics.
If args.metric_for_best_model is not set, the loss is used.
Returns:
bool: True if a new best metric was found, else False
"""
is_new_best_metric = False

if self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model

if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"

try:
metric_value = metrics[metric_to_check]
except KeyError as exc:
raise KeyError(
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
) from exc

operator = np.greater if self.args.greater_is_better else np.less

if self.state.best_metric is None:
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

is_new_best_metric = True

return is_new_best_metric

def create_model_card(
self,
model_name: Optional[str] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ORPOConfig(TrainingArguments):
Number of processes to use for processing the dataset.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class SFTConfig(TrainingArguments):
Whether to offload the activations to the CPU.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=2e-5,
Expand Down
Loading