185 changes: 154 additions & 31 deletions trl/trainer/grpo_trainer.py
@@ -14,6 +14,7 @@

import contextlib
import functools
import logging
import os
import textwrap
import warnings
@@ -72,6 +73,8 @@
if is_wandb_available():
import wandb

logger = logging.getLogger(__name__)

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
@@ -277,25 +280,25 @@ def __init__(

# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
self._model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
torch_dtype = self._model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
self._model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
self._model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else self._model_init_kwargs.get("use_cache")
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
model = AutoModelForCausalLM.from_pretrained(model, **self._model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
@@ -319,7 +322,7 @@ def __init__(
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **self._model_init_kwargs)
elif is_peft_model(model):
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
@@ -338,7 +341,7 @@ def __init__(
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
reward_func, num_labels=1, **self._model_init_kwargs
)
self.reward_funcs = reward_funcs

@@ -353,26 +356,7 @@ def __init__(
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError("The number of reward processing classes must match the number of reward functions.")

for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
self.reward_processing_classes = self._make_reward_processing_classes(reward_funcs, reward_processing_classes)

# Data collator
def data_collator(features): # No data collation is needed in GRPO
@@ -575,6 +559,33 @@ def new_group_context():
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

@staticmethod
def _make_reward_processing_classes(
reward_funcs,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
):
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError("The number of reward processing classes must match the number of reward functions.")

for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs, strict=True)
):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
return reward_processing_classes

def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@@ -806,7 +817,7 @@ def _extract_completions(
def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: list[str], device) -> torch.Tensor:
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
zip(self.reward_funcs, self.reward_processing_classes, strict=True)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
@@ -817,10 +828,10 @@ def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: lis
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
reward_inputs = reward_processing_class(
text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
@@ -1177,3 +1188,115 @@ def create_model_card(
)

model_card.save(os.path.join(self.args.output_dir, "README.md"))


class GRPOTrainerWithEval(GRPOTrainer):
def __init__(
self,
model: str | PreTrainedModel,
train_reward_funcs: RewardFunc | list[RewardFunc],
eval_reward_funcs: RewardFunc | list[RewardFunc] | None = None,
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | None = None,
train_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
eval_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
**kwargs,
):
super().__init__(
model=model,
reward_funcs=train_reward_funcs,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=train_reward_processing_classes,
**kwargs,
)

# Store training reward functions reference
self.train_reward_funcs = self.reward_funcs
self.train_reward_processing_classes = self.reward_processing_classes

if eval_reward_funcs is not None:
            # Custom evaluation reward functions were provided, so set them up here

if "compute_metrics" in kwargs:
                logger.warning(
                    "A custom `compute_metrics` function was passed; make sure it uses the"
                    " evaluation reward functions rather than the training ones."
                )

            # Mirror the reward_funcs handling from GRPOTrainer.__init__
if not isinstance(eval_reward_funcs, list):
eval_reward_funcs = [eval_reward_funcs]
for i, reward_func in enumerate(eval_reward_funcs):
if isinstance(reward_func, str):
eval_reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **self._model_init_kwargs
)
self.eval_reward_funcs = eval_reward_funcs
self.eval_reward_processing_classes = self._make_reward_processing_classes(
eval_reward_funcs, eval_reward_processing_classes
)
else:
# We don't have any, so we just reuse the training ones
self.eval_reward_funcs = self.train_reward_funcs
self.eval_reward_processing_classes = self.train_reward_processing_classes

def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: list[str], device) -> torch.Tensor:
if self.control.should_evaluate:
reward_funcs = self.eval_reward_funcs
reward_processing_classes = self.eval_reward_processing_classes
else:
reward_funcs = self.train_reward_funcs
reward_processing_classes = self.train_reward_processing_classes

rewards_per_func = torch.zeros(len(prompts), len(reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(reward_funcs, reward_processing_classes, strict=True)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
else:
reward_func_name = reward_func.__name__
with profiling_context(self, reward_func_name):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in {"prompt", "completion"}]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
return gather(rewards_per_func)

def compute_reward_metrics(self, eval_prediction: EvalPrediction) -> dict[str, float]:
if not self.control.should_evaluate:
raise RuntimeError("We are supposed to be in evaluation mode.")

avg_reward_per_func = eval_prediction.predictions.mean(axis=0)
metrics: dict[str, float] = {}
for i, reward_func in enumerate(self.eval_reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
metrics[f"rewards/{reward_func_name}"] = avg_reward_per_func[i].item()
return metrics
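

For reference, a minimal usage sketch of the new GRPOTrainerWithEval class follows, assuming the class is importable from trl.trainer.grpo_trainer as added in this diff. The dataset, model ID, eval_strategy settings, and the two toy reward functions are illustrative assumptions, not part of the PR.

# Usage sketch (illustrative, not part of this diff): train with one reward function and
# evaluate with another. Assumes GRPOTrainerWithEval is importable from trl.trainer.grpo_trainer;
# the dataset, model ID, and reward functions below are placeholder choices.
from datasets import load_dataset
from trl import GRPOConfig
from trl.trainer.grpo_trainer import GRPOTrainerWithEval


def brevity_reward(prompts, completions, **kwargs):
    # Toy training reward: shorter completions score higher.
    return [-float(len(completion)) for completion in completions]


def tldr_marker_reward(prompts, completions, **kwargs):
    # Toy evaluation reward: 1.0 if the completion contains "TL;DR", else 0.0.
    return [1.0 if "tl;dr" in completion.lower() else 0.0 for completion in completions]


train_dataset = load_dataset("trl-lib/tldr", split="train")
eval_dataset = load_dataset("trl-lib/tldr", split="validation")

training_args = GRPOConfig(
    output_dir="grpo-with-eval",
    eval_strategy="steps",  # evaluation has to run for the eval reward functions to be used
    eval_steps=100,
)

trainer = GRPOTrainerWithEval(
    model="Qwen/Qwen2-0.5B-Instruct",
    train_reward_funcs=brevity_reward,      # used during training steps
    eval_reward_funcs=tldr_marker_reward,   # used when self.control.should_evaluate is True
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

How compute_reward_metrics is meant to be hooked into evaluation is not shown in this hunk; assigning it as the trainer's compute_metrics after construction (e.g. trainer.compute_metrics = trainer.compute_reward_metrics) is one plausible wiring, but that is an assumption rather than something the diff establishes.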