@@ -18,6 +18,42 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
         """
         raise NotImplementedError("Preference loss function must be implemented.")
 
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+        compute_nll_loss=True,
+    ):
+        len_chosen_chunk = target_chunk.shape[0] // 2
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+        return chosen_logps, rejected_logps, chosen_nll_loss
+
     @staticmethod
     def forward(
         ctx,
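Note: the new chunk_forward helper is a plain @staticmethod, so it can be exercised in isolation. Below is a minimal sketch (sizes and the import path are assumptions for illustration, not part of this commit): a chunk is the chosen half concatenated with the rejected half along the batch dimension, and the helper returns per-sequence average log-probabilities for each half plus the summed NLL loss over the chosen half.

import torch

from liger_kernel.chunked_loss.fused_linear_preference import (  # module path assumed
    LigerFusedLinearPreferenceBase,
)

# Hypothetical sizes: 2 chosen + 2 rejected sequences, 6 tokens, hidden 8, vocab 32.
B, T, H, V = 2, 6, 8, 32
input_chunk = torch.randn(2 * B, T, H)          # hidden states, chosen rows first
weight = torch.randn(V, H)                      # lm_head weight
target_chunk = torch.randint(0, V, (2 * B, T))  # token labels

chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward(
    input_chunk, weight, target_chunk, bias=None, ignore_index=-100
)
# chosen_logps / rejected_logps: shape (B,), average log-prob per sequence.
# chosen_nll_loss: scalar summed over the chosen half; _compute_loss normalizes it later.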
@@ -32,6 +68,9 @@ def forward(
         beta=0.1,
         compute_nll_loss=True,
         compiled=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
@@ -49,7 +88,11 @@ def forward(
             ignore_index (int): Index to ignore for loss computation.
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
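For context on the new arguments: ref_weight and ref_bias are the output-projection parameters of a frozen reference model, used only to produce reference log-probabilities. One plausible way to obtain them, assuming an HF-style causal LM whose projection is named lm_head (an assumption, not something this commit prescribes):

from transformers import AutoModelForCausalLM

ref_model = AutoModelForCausalLM.from_pretrained("path/to/reference-model").eval()  # hypothetical checkpoint
ref_weight = ref_model.lm_head.weight                 # (vocab_size, hidden_size)
ref_bias = getattr(ref_model.lm_head, "bias", None)   # (vocab_size,) or None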
@@ -61,7 +104,6 @@ def forward(
         grad_bias = torch.zeros_like(bias) if bias is not None else None
         loss_acc = torch.zeros((), device=_input.device)
 
-        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         loss_func_to_call = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
             preference_loss_fn=loss_fn,
@@ -70,6 +112,9 @@ def forward(
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             full_target=target,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
             **loss_kwargs,
         )
 
@@ -101,6 +146,7 @@ def accumulate_chunk(input_chunk, target_chunk):
             accumulate_chunk = torch.compile(accumulate_chunk)
 
         len_chosen = target.shape[0] // 2
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
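Moving the chunks computation here does not change the arithmetic; it now sits next to the torch.chunk calls that consume it. For intuition, a small worked example with hypothetical sizes (the CHUNK_SIZE value below is illustrative, not the one used in forward):

import torch

CHUNK_SIZE = 2                      # hypothetical value
_input = torch.randn(16, 10, 8)     # 8 chosen rows followed by 8 rejected rows
len_chosen = 16 // 2                # 8
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))   # max(1, 16 // 4) = 4
chosen_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
# -> 4 tensors of shape (2, 10, 8); the rejected half splits the same way, and
#    accumulate_chunk pairs one chosen chunk with one rejected chunk per step.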
@@ -159,6 +205,9 @@ def _compute_loss(
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
@@ -173,38 +222,41 @@ def _compute_loss(
             ignore_index (int): Index to ignore for loss computation.
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        len_chosen_chunk = target_chunk.shape[0] // 2
-
-        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-        if bias is not None:
-            logits_chunk = logits_chunk + bias
-        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-        chosen_nll_loss = 0.0
-        if compute_nll_loss:
-            chosen_nll_loss = F.nll_loss(
-                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
-                reduction="sum",
+        chosen_logps, rejected_logps, chosen_nll_loss = (
+            LigerFusedLinearPreferenceBase.chunk_forward(
+                input_chunk,
+                weight,
+                target_chunk,
+                bias=bias,
                 ignore_index=ignore_index,
+                compute_nll_loss=compute_nll_loss,
             )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-
-        loss_mask = target_chunk != ignore_index
-        label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
         )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
 
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        if use_ref_model:
+            with torch.no_grad():
+                ref_chosen_logps, ref_rejected_logps, _ = (
+                    LigerFusedLinearPreferenceBase.chunk_forward(
+                        input_chunk,
+                        ref_weight,
+                        target_chunk,
+                        ref_bias,
+                        ignore_index=ignore_index,
+                        compute_nll_loss=False,
+                    )
+                )
+            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
 
         alignment_loss = preference_loss_fn(
             chosen_logps, rejected_logps, beta=beta, **loss_kwargs
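With use_ref_model enabled, _compute_loss forwards ref_chosen_logps and ref_rejected_logps to the preference loss through **loss_kwargs. Below is a sketch of a DPO-style loss_fn that consumes them; the function name and the sum reduction are illustrative choices, not code from this commit.

import torch.nn.functional as F

def dpo_style_loss_fn(chosen_logps, rejected_logps, beta=0.1,
                      ref_chosen_logps=None, ref_rejected_logps=None):
    # Only meaningful when use_ref_model=True supplies the reference log-probs.
    policy_margin = chosen_logps - rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    # DPO objective: -log sigmoid(beta * (policy_margin - ref_margin)).
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).sum()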