🍃 GRPO - Do not load reference model when beta == 0 #2806


Merged · 8 commits · Feb 18, 2025
tests/test_grpo_trainer.py (19 changes: 18 additions & 1 deletion)
@@ -499,6 +499,23 @@ def test_training_with_sync_ref_model(self):
            new_param = trainer.model.get_parameter(n)
            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")


    def test_beta_zero_no_ref_model_and_no_kl(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                beta=0,  # set beta to 0 to test the case where the reference model is not used
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=32,
                max_steps=1,  # run only one training step to keep the test fast
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",

    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
    @require_torch_accelerator
    @require_peft
@@ -532,4 +549,4 @@ def test_training_vllm_and_peft(self):
            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
trl/trainer/grpo_config.py (1 change: 1 addition & 0 deletions)
@@ -87,6 +87,7 @@ class GRPOConfig(TrainingArguments):
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient.
            If `0.0`, the reference model is not loaded, reducing memory usage and improving training speed.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
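For context, a minimal usage sketch of the new behavior. The checkpoint, dataset, and reward function below are illustrative placeholders, not taken from this PR; only the `beta=0.0` setting is the point:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Illustrative prompt-only dataset; any dataset with a "prompt" column works.
dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-no-ref-model",
    beta=0.0,  # KL penalty disabled, so no reference model is loaded
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # illustrative checkpoint
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

With `beta=0.0`, `trainer.ref_model` stays `None` and no KL term enters the loss, which is what the trainer changes below implement.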
trl/trainer/grpo_trainer.py (36 changes: 24 additions & 12 deletions)
@@ -243,11 +243,16 @@ def __init__(
"This argument can only be used when the `model` argument is a string."
)

self.beta = args.beta

if peft_config is not None:
model = get_peft_model(model, peft_config)

# Reference model
if is_deepspeed_zero3_enabled():
# If beta is 0, the reference model is not needed
if self.beta == 0:
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif peft_config is None:
# If PEFT configuration is not provided, create a reference model based on the initial model.
@@ -313,8 +318,6 @@ def data_collator(features): # No data collation is needed in GRPO
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm

        self.beta = args.beta

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
@@ -578,7 +581,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode():
            if self.ref_model is not None:
            if self.beta == 0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
@@ -697,23 +702,30 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
        advantages = inputs["advantages"]

        # Compute the KL divergence between the model and the reference model if beta is not 0
        ref_per_token_logps = inputs["ref_per_token_logps"]
        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        if ref_per_token_logps is None:
            per_token_loss = -per_token_loss
        else:
            # we need to compute the KL divergence between the model and the reference model
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

            per_token_loss = -(per_token_loss - self.beta * per_token_kl)

            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
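As a side note on the loss above: the KL term uses the k3 estimator `exp(delta) - delta - 1` with `delta = ref_logps - logps`, and the policy term uses the `x - x.detach()` trick so the loss value reduces to the advantages while gradients still flow through the current log-probabilities. A standalone sketch (toy tensors, not part of the PR) of both pieces and of the `beta == 0` shortcut:

```python
import torch

# Toy per-token log-probabilities; shapes and values are made up for illustration.
per_token_logps = torch.tensor([[-1.2, -0.7, -2.3]], requires_grad=True)  # current policy
ref_per_token_logps = torch.tensor([[-1.0, -0.9, -2.0]])                  # frozen reference
advantages = torch.tensor([0.5])

# k3 estimator of KL(policy || reference), per token: exp(delta) - delta - 1 >= 0,
# equal to 0 only where the two log-probabilities match.
delta = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(delta) - delta - 1

# x - x.detach(): the forward value is exp(0) = 1, so the loss equals the advantages
# numerically, but the gradient with respect to per_token_logps is preserved.
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)

beta = 0.04
loss_with_kl = -(per_token_loss - beta * per_token_kl)  # beta != 0: reference logps required
loss_beta_zero = -per_token_loss                        # beta == 0: no reference model needed
```

Because the KL term is the only place the reference log-probabilities are used, setting `beta = 0` lets the trainer skip both the reference forward pass and the reference model itself.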