
Commit 4e7f0a5

AIR-hl, kashif, and qgallouedec authored
🤧 LD-DPO support (#3458)
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 17a9069 commit 4e7f0a5

File tree: 4 files changed, +89 −7 lines changed

docs/source/dpo_trainer.md

Lines changed: 4 additions & 0 deletions
@@ -168,6 +168,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ
 
 The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
 
+### LD-DPO loss
+
+The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
+
 ### For Mixture of Experts Models: Enabling the auxiliary loss
 
 MOEs are the most efficient if the load is about equally distributed between experts.
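For context (not part of the diff): a minimal usage sketch of the new `ld_alpha` option, mirroring the test added below. The tiny model and dataset are the internal testing assets this commit uses, not recommendations.

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# ld_alpha in (0.0, 1.0) down-weights the log-probabilities of the "verbose" tail of each
# response (tokens beyond the length of the shorter response in the pair); ld_alpha=None
# keeps the standard DPO loss.
training_args = DPOConfig(output_dir="ld-dpo-model", ld_alpha=0.5, report_to="none")

trainer = DPOTrainer(
    model=model_id,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```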

tests/test_dpo_trainer.py

Lines changed: 31 additions & 0 deletions
@@ -1258,6 +1258,37 @@ def dummy_compute_metrics(*args, **kwargs):
 
         self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
 
+    def test_train_with_length_desensitization(self):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                learning_rate=9e-1,
+                ld_alpha=0.5,
+                report_to="none",
+            )
+            trainer = DPOTrainer(
+                model=model_id,
+                args=training_args,
+                processing_class=tokenizer,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the parameters have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if param.sum() != 0:  # ignore 0 biases
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
 
 @require_vision
 class DPOVisionTrainerTester(unittest.TestCase):

trl/trainer/dpo_config.py

Lines changed: 16 additions & 3 deletions
@@ -131,14 +131,19 @@ class DPOConfig(TrainingArguments):
             Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
             probability to all responses.
         label_smoothing (`float`, *optional*, defaults to `0.0`):
-            Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
+            Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and
             [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
         use_weighting (`bool`, *optional*, defaults to `False`):
-            Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
+            Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
         rpo_alpha (`float`, *optional*, defaults to `None`):
-            α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
+            α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
             weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
             DPO loss. The paper recommends `rpo_alpha=1.0`.
+        ld_alpha (`float` or `None`, *optional*, defaults to `None`):
+            α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
+            of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
+            part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
+            `0.0` and `1.0`.
         discopop_tau (`float`, *optional*, defaults to `0.05`):
             τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
             the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
@@ -346,6 +351,14 @@ class DPOConfig(TrainingArguments):
             "`rpo_alpha=1.0`."
         },
     )
+    ld_alpha: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
+            "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is "
+            "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.",
+        },
+    )
     discopop_tau: float = field(
         default=0.05,
         metadata={
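As a reading aid, the reweighting that `ld_alpha` (written \\( \alpha \\) below) applies can be sketched as follows, with \\( l_p = \min(|y_w|, |y_l|) \\) the public length of a preference pair and \\( y \\) either response:

\\[
\log \hat{\pi}_\theta(y \mid x) = \sum_{t=1}^{l_p} \log \pi_\theta(y_t \mid x, y_{<t}) + \alpha \sum_{t=l_p+1}^{|y|} \log \pi_\theta(y_t \mid x, y_{<t})
\\]

This reweighted log-probability replaces the usual sequence log-probability of both the chosen and the rejected response in the DPO loss, while the reference model keeps the unweighted sum, matching the `is_ref_model` guard added in `dpo_trainer.py` below.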

trl/trainer/dpo_trainer.py

Lines changed: 38 additions & 4 deletions
@@ -804,9 +804,9 @@ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict:
         with torch.no_grad(), compte_ref_context_manager:
             if self.ref_model is None:
                 with self.null_ref_context():
-                    ref_model_output = self.concatenated_forward(self.model, batch)
+                    ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
             else:
-                ref_model_output = self.concatenated_forward(self.ref_model, batch)
+                ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
         return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
 
     @staticmethod
@@ -1066,10 +1066,22 @@ def dpo_loss(
 
         return losses, chosen_rewards, rejected_rewards
 
-    def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
-        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
+    ):
+        """
+        Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
+
+        Args:
+            model:
+                Model to run the forward pass on.
+            batch:
+                Batch of input data.
+            is_ref_model:
+                Whether this method is being called for the reference model. If `True`, length desensitization is not
+                applied.
         """
         num_examples = batch["prompt_input_ids"].shape[0]
 
@@ -1218,6 +1230,28 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
         if self.loss_type == "ipo":
             all_logps = all_logps / loss_mask.sum(-1)
 
+        if self.args.ld_alpha is not None and not is_ref_model:
+            # Compute response lengths based on loss_mask
+            completion_lengths = loss_mask.sum(dim=1)
+
+            chosen_lengths = completion_lengths[:num_examples]
+            rejected_lengths = completion_lengths[num_examples:]
+            public_lengths = torch.min(chosen_lengths, rejected_lengths)  # l_p in the paper
+            public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
+
+            seq_len = per_token_logps.size(1)
+            position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+            ld_mask = position_ids < public_lengths.unsqueeze(1)
+            mask = position_ids < completion_lengths.unsqueeze(1)
+
+            front_mask = (ld_mask & mask).float()
+            rear_mask = (~ld_mask & mask).float()
+            front_logps = (per_token_logps * front_mask).sum(dim=1)
+            rear_logps = (per_token_logps * rear_mask).sum(dim=1)
+
+            all_logps = front_logps + self.args.ld_alpha * rear_logps
+
         output["chosen_logps"] = all_logps[:num_examples]
         output["rejected_logps"] = all_logps[num_examples:]
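To make the added block easier to follow, here is a standalone sketch of the same front/rear split on toy tensors. Values and shapes are made up for illustration, `ld_alpha` stands in for `self.args.ld_alpha`, and the sketch assumes completion tokens start at position 0, as they do at this point in `concatenated_forward`.

```python
import torch

# Toy per-token log-probs for 2 chosen + 2 rejected completions (2 pairs, 5 completion positions).
per_token_logps = torch.full((4, 5), -1.0)
# loss_mask marks real completion tokens: chosen lengths are 5 and 3, rejected lengths are 2 and 4.
loss_mask = torch.tensor([
    [1, 1, 1, 1, 1],  # chosen 1
    [1, 1, 1, 0, 0],  # chosen 2
    [1, 1, 0, 0, 0],  # rejected 1
    [1, 1, 1, 1, 0],  # rejected 2
])
num_examples, ld_alpha = 2, 0.5

completion_lengths = loss_mask.sum(dim=1)
chosen_lengths = completion_lengths[:num_examples]
rejected_lengths = completion_lengths[num_examples:]
public_lengths = torch.min(chosen_lengths, rejected_lengths)         # l_p per pair: [2, 3]
public_lengths = torch.cat([public_lengths, public_lengths], dim=0)  # align with the chosen+rejected batch

position_ids = torch.arange(per_token_logps.size(1)).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)   # tokens up to the public length l_p
mask = position_ids < completion_lengths.unsqueeze(1)  # all real completion tokens

front_logps = (per_token_logps * (ld_mask & mask).float()).sum(dim=1)
rear_logps = (per_token_logps * (~ld_mask & mask).float()).sum(dim=1)
all_logps = front_logps + ld_alpha * rear_logps
print(all_logps)  # tensor([-3.5000, -3.0000, -2.0000, -3.5000])
```

With `ld_alpha=1.0` this reduces to the plain masked sum of log-probabilities (the same value standard DPO uses); smaller values shrink the contribution of whatever each response says beyond the length of the shorter response in its pair.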
