+from abc import abstractmethod
+from functools import partial
+
 import torch
+from torch.nn import functional as F


 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+    @abstractmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute preference loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
     @staticmethod
     def forward(
         ctx,
@@ -11,6 +27,9 @@ def forward(
         bias=None,
         loss_fn=None,
         chunk_size=1,
+        compute_nll_loss=True,
+        ignore_index=-100,
+        beta=0.1,
         compiled=True,
     ):
         """
@@ -24,6 +43,9 @@ def forward(
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -36,21 +58,33 @@ def forward(
         loss_acc = torch.zeros((), device=_input.device)

         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+        loss_func_to_call = partial(
+            LigerFusedLinearPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            full_target=target,
+        )

         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
                     input_chunk, weight, target_chunk, bias
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
                     input_chunk, weight, target_chunk
                 )
                 grad_weight.add_(chunk_grad_weight)
@@ -105,3 +139,68 @@ def backward(ctx, grad_output):
         grad_bias = grad_bias * grad_output if grad_bias is not None else None

         return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        len_chosen_chunk = target_chunk.shape[0] // 2
+
+        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+            chosen_nll_loss = (
+                chosen_nll_loss
+                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        alignment_loss = preference_loss_fn(
+            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        )
+        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+
+        loss = chosen_nll_loss - alignment_loss
+        return loss, (alignment_loss, chosen_logps, rejected_logps)
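
For context, here is a minimal sketch (not part of this diff) of how a concrete subclass might plug a preference loss into this base class and exercise `_compute_loss` on a single chunk. The `MyORPOLoss` name, the ORPO-style odds-ratio formula, and the toy tensor shapes are illustrative assumptions; only the `preference_loss_fn` and `_compute_loss` signatures come from the code above.

```python
import torch
from torch.nn import functional as F

# Illustrative subclass (name and loss formula are assumptions, not from this diff):
# an ORPO-style odds-ratio preference loss over the averaged per-sequence log-probs.
class MyORPOLoss(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Log-odds of the chosen vs. rejected average token probabilities.
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps))
            - torch.log1p(-torch.exp(rejected_logps))
        )
        return beta * F.logsigmoid(log_odds).sum()


# Tiny smoke test of _compute_loss with made-up shapes: 1 chosen + 1 rejected sequence.
seq_len, hidden, vocab = 4, 8, 16
input_chunk = torch.randn(2, seq_len, hidden)         # (2 * chunk_size, T, H)
target_chunk = torch.randint(0, vocab, (2, seq_len))  # (2 * chunk_size, T)
weight = torch.randn(vocab, hidden)                   # (V, H)

loss, (pref_loss, chosen_logps, rejected_logps) = (
    LigerFusedLinearPreferenceBase._compute_loss(
        input_chunk,
        weight,
        target_chunk,
        preference_loss_fn=MyORPOLoss.preference_loss_fn,
        full_target=target_chunk,  # the whole batch is this single chunk here
        beta=0.1,
    )
)
print(loss, pref_loss, chosen_logps, rejected_logps)
```

In the fused forward pass, this same `preference_loss_fn` is what a caller would pass as `loss_fn`; the diff then binds it, together with `ignore_index`, `beta`, `compute_nll_loss`, and `full_target`, into `_compute_loss` via `partial` and differentiates the result chunk by chunk with `torch.func.grad_and_value`.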