1- from functools import partial
2-
3- import torch
41import torch .nn .functional as F
52
63from liger_kernel .chunked_loss .fused_linear_preference import (
74 LigerFusedLinearPreferenceBase ,
85)
96
107
11- def dpo_loss (chosen_logps , rejected_logps , beta = 0.1 ):
12- """
13- Compute DPO loss (Direct Preference Optimization).
14- Args:
15- chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
16- rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
17- beta (float): Weight for the direct preference loss.
18- """
19- logits_diff = beta * (chosen_logps - rejected_logps )
20- losses = - F .logsigmoid (logits_diff )
21- return losses .sum ()
22-
23-
24- def _compute_dpo_loss (
25- input_chunk ,
26- weight ,
27- target_chunk ,
28- bias = None ,
29- full_target = None ,
30- ignore_index = - 100 ,
31- beta = 0.1 ,
32- compute_nll_loss = True ,
33- ):
34- """
35- Compute DPO loss for a chunk of input and target.
36- Args:
37- input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
38- weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
39- target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
40- bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
41- full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
42- ignore_index (int): Index to ignore for loss computation.
43- beta (float): Weight for the direct preference loss.
44- """
45-
46- len_chosen_chunk = target_chunk .shape [0 ] // 2
47-
48- logits_chunk = input_chunk @ weight .t () # chunk_size x V
49- if bias is not None :
50- logits_chunk = logits_chunk + bias
51- log_probs_chunk = F .log_softmax (
52- logits_chunk .float (), dim = - 1
53- ) # Normalize the unnorm_logits
54-
55- # Compute NLL loss for chosen responses
56- chosen_nll_loss = 0.0
57- if compute_nll_loss :
58- chosen_nll_loss = F .nll_loss (
59- log_probs_chunk [:len_chosen_chunk ].view (- 1 , log_probs_chunk .shape [- 1 ]),
60- target_chunk [:len_chosen_chunk ].view (- 1 ),
61- reduction = "sum" ,
62- ignore_index = ignore_index ,
63- )
64- chosen_nll_loss = (
65- chosen_nll_loss
66- / (full_target [: full_target .shape [0 ] // 2 ] != ignore_index ).sum ()
67- )
68-
69- # Compute log probabilities for both chosen and rejected
70- loss_mask = target_chunk != ignore_index
71- label_chunk = torch .where (loss_mask , target_chunk , 0 )
72-
73- per_token_logps = log_probs_chunk .gather (- 1 , label_chunk .unsqueeze (- 1 )).squeeze (- 1 )
74- average_log_prob = (per_token_logps * loss_mask ).sum (- 1 ) / loss_mask .sum (- 1 )
75-
76- chosen_logps = average_log_prob [:len_chosen_chunk ]
77- rejected_logps = average_log_prob [len_chosen_chunk :]
78-
79- # Compute DPO loss
80- preference_loss = dpo_loss (chosen_logps , rejected_logps , beta = beta )
81- preference_loss = preference_loss / (full_target .shape [0 ] // 2 )
82-
83- # Total loss combines NLL and DPO loss
84- loss = chosen_nll_loss + preference_loss
85- return loss , (preference_loss , chosen_logps , rejected_logps )
8+ class LigerFusedLinearDPOFunction (LigerFusedLinearPreferenceBase ):
869
10+ @staticmethod
11+ def preference_loss_fn (chosen_logps , rejected_logps , beta = 0.1 ):
12+ """
13+ Compute DPO loss (Direct Preference Optimization).
14+ Args:
15+ chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
16+ rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
17+ beta (float): Weight for the direct preference loss.
18+ """
19+ logits_diff = beta * (chosen_logps - rejected_logps )
20+ losses = - F .logsigmoid (logits_diff )
21+ return losses .sum ()
8722
88- class LigerFusedLinearDPOFunction (LigerFusedLinearPreferenceBase ):
8923 @staticmethod
9024 def forward (
9125 ctx ,
@@ -101,17 +35,18 @@ def forward(
10135 """
10236 Fused linear layer with DPO (Direct Preference Optimization) loss.
10337 Handles both the forward and backward pass of the final linear layer with DPO loss.
104- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
10538 """
106- dpo_loss_fn = partial (
107- _compute_dpo_loss ,
108- full_target = target ,
39+ return LigerFusedLinearPreferenceBase .forward (
40+ ctx = ctx ,
41+ _input = _input ,
42+ weight = weight ,
43+ target = target ,
44+ bias = bias ,
45+ loss_fn = LigerFusedLinearDPOFunction .preference_loss_fn ,
46+ compute_nll_loss = compute_nll_loss ,
10947 ignore_index = ignore_index ,
11048 beta = beta ,
111- compute_nll_loss = compute_nll_loss ,
112- )
113- return LigerFusedLinearPreferenceBase .forward (
114- ctx , _input , weight , target , bias , loss_fn = dpo_loss_fn
49+ compiled = compiled ,
11550 )
11651
11752 @staticmethod
0 commit comments