Commit 854c1b3

align with interface
1 parent a50d38b commit 854c1b3

File tree

1 file changed: +22 −87 lines changed


src/liger_kernel/chunked_loss/dpo_loss.py

Lines changed: 22 additions & 87 deletions
@@ -1,91 +1,25 @@
-from functools import partial
-
-import torch
 import torch.nn.functional as F
 
 from liger_kernel.chunked_loss.fused_linear_preference import (
     LigerFusedLinearPreferenceBase,
 )
 
 
-def dpo_loss(chosen_logps, rejected_logps, beta=0.1):
-    """
-    Compute DPO loss (Direct Preference Optimization).
-    Args:
-        chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-        rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-        beta (float): Weight for the direct preference loss.
-    """
-    logits_diff = beta * (chosen_logps - rejected_logps)
-    losses = -F.logsigmoid(logits_diff)
-    return losses.sum()
-
-
-def _compute_dpo_loss(
-    input_chunk,
-    weight,
-    target_chunk,
-    bias=None,
-    full_target=None,
-    ignore_index=-100,
-    beta=0.1,
-    compute_nll_loss=True,
-):
-    """
-    Compute DPO loss for a chunk of input and target.
-    Args:
-        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-        ignore_index (int): Index to ignore for loss computation.
-        beta (float): Weight for the direct preference loss.
-    """
-
-    len_chosen_chunk = target_chunk.shape[0] // 2
-
-    logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-    if bias is not None:
-        logits_chunk = logits_chunk + bias
-    log_probs_chunk = F.log_softmax(
-        logits_chunk.float(), dim=-1
-    )  # Normalize the unnorm_logits
-
-    # Compute NLL loss for chosen responses
-    chosen_nll_loss = 0.0
-    if compute_nll_loss:
-        chosen_nll_loss = F.nll_loss(
-            log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-            target_chunk[:len_chosen_chunk].view(-1),
-            reduction="sum",
-            ignore_index=ignore_index,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-
-    # Compute log probabilities for both chosen and rejected
-    loss_mask = target_chunk != ignore_index
-    label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-    per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-    average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-    chosen_logps = average_log_prob[:len_chosen_chunk]
-    rejected_logps = average_log_prob[len_chosen_chunk:]
-
-    # Compute DPO loss
-    preference_loss = dpo_loss(chosen_logps, rejected_logps, beta=beta)
-    preference_loss = preference_loss / (full_target.shape[0] // 2)
-
-    # Total loss combines NLL and DPO loss
-    loss = chosen_nll_loss + preference_loss
-    return loss, (preference_loss, chosen_logps, rejected_logps)
+class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute DPO loss (Direct Preference Optimization).
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the direct preference loss.
+        """
+        logits_diff = beta * (chosen_logps - rejected_logps)
+        losses = -F.logsigmoid(logits_diff)
+        return losses.sum()
 
-class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def forward(
         ctx,
@@ -101,17 +35,18 @@ def forward(
         """
         Fused linear layer with DPO (Direct Preference Optimization) loss.
         Handles both the forward and backward pass of the final linear layer with DPO loss.
-        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
         """
-        dpo_loss_fn = partial(
-            _compute_dpo_loss,
-            full_target=target,
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
             ignore_index=ignore_index,
             beta=beta,
-            compute_nll_loss=compute_nll_loss,
-        )
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx, _input, weight, target, bias, loss_fn=dpo_loss_fn
+            compiled=compiled,
         )
 
     @staticmethod
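
For reference, a minimal standalone sketch of what the new preference_loss_fn static method computes. The tensor values below are made up for illustration and are not part of this commit; only the formula itself, -logsigmoid(beta * (chosen_logps - rejected_logps)) summed over the batch, comes from the diff above.

import torch
import torch.nn.functional as F

# Hypothetical per-sequence average log-probabilities for a batch of 4 preference pairs.
chosen_logps = torch.tensor([-1.2, -0.8, -1.5, -0.9])
rejected_logps = torch.tensor([-1.6, -1.1, -1.4, -1.3])
beta = 0.1

# Same computation as LigerFusedLinearDPOFunction.preference_loss_fn:
# -logsigmoid(beta * (chosen - rejected)), summed over the batch.
logits_diff = beta * (chosen_logps - rejected_logps)
loss = -F.logsigmoid(logits_diff).sum()
print(loss)  # scalar DPO preference loss for the batch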
