Skip to content

Commit d907ec0

Browse files
authored
Add reference model logps to chunked loss interface and fix DPO loss fn (#405)
Accommodate reference model logps in the chunked loss interface and make the DPO loss use reference model logps in its loss function ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> As title <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 317ff43 commit d907ec0

File tree

4 files changed

+216
-46
lines changed

4 files changed

+216
-46
lines changed

src/liger_kernel/chunked_loss/dpo_loss.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,31 @@
99
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
1010

1111
@staticmethod
12-
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
12+
def preference_loss_fn(
13+
chosen_logps,
14+
rejected_logps,
15+
ref_chosen_logps=None,
16+
ref_rejected_logps=None,
17+
beta=0.1,
18+
):
1319
"""
1420
Compute DPO loss (Direct Preference Optimization).
1521
Args:
1622
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
1723
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
24+
ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
25+
ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
1826
beta (float): Weight for the direct preference loss.
1927
"""
20-
logits_diff = beta * (chosen_logps - rejected_logps)
28+
if ref_chosen_logps is None:
29+
ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
30+
if ref_rejected_logps is None:
31+
ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
32+
33+
chosen_logratios = chosen_logps - ref_chosen_logps
34+
rejected_logratios = rejected_logps - ref_rejected_logps
35+
36+
logits_diff = beta * (chosen_logratios - rejected_logratios)
2137
losses = -F.logsigmoid(logits_diff)
2238
return losses.sum()
2339

@@ -28,10 +44,13 @@ def forward(
2844
weight,
2945
target,
3046
bias=None,
47+
ref_weight=None,
48+
ref_bias=None,
3149
ignore_index=-100,
3250
beta=0.1,
3351
compute_nll_loss=True,
3452
compiled=True,
53+
use_ref_model=True,
3554
):
3655
"""
3756
Fused linear layer with DPO (Direct Preference Optimization) loss.
@@ -48,14 +67,17 @@ def forward(
4867
beta=beta,
4968
compute_nll_loss=compute_nll_loss,
5069
compiled=compiled,
70+
use_ref_model=use_ref_model,
71+
ref_weight=ref_weight,
72+
ref_bias=ref_bias,
5173
)
5274

5375
@staticmethod
5476
def backward(ctx, grad_output):
5577
# Get gradients for _input, weight, bias, and target from the base class
5678
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
5779
# Return these gradients, followed by None for the remaining inputs
58-
return *grads, None, None, None, None
80+
return *grads, None, None, None, None, None, None, None
5981

6082

6183
class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -69,26 +91,36 @@ def __init__(
6991
beta: float = 0.1,
7092
compute_nll_loss: bool = True,
7193
compiled: bool = True,
94+
use_ref_model: bool = False,
7295
):
7396
"""
7497
Args:
7598
ignore_index (int): Index to ignore in the loss.
7699
beta (float): Weight for the odds ratio loss.
100+
compute_nll_loss (bool): Whether to compute the NLL loss.
101+
compiled (bool): Whether to use the torch compiled kernel.
102+
use_ref_model (bool): Whether to use a reference model for the DPO loss.
77103
"""
78104
super().__init__()
79105
self.ignore_index = ignore_index
80106
self.beta = beta
81107
self.compute_nll_loss = compute_nll_loss
82108
self.compiled = compiled
109+
self.use_ref_model = use_ref_model
83110

84-
def forward(self, lin_weight, _input, target, bias=None):
111+
def forward(
112+
self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
113+
):
85114
return LigerFusedLinearDPOFunction.apply(
86115
_input,
87116
lin_weight,
88117
target,
89118
bias,
119+
ref_weight,
120+
ref_bias,
90121
self.ignore_index,
91122
self.beta,
92123
self.compute_nll_loss,
93124
self.compiled,
125+
self.use_ref_model,
94126
)

src/liger_kernel/chunked_loss/fused_linear_preference.py

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,42 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
1818
"""
1919
raise NotImplementedError("Preference loss function must be implemented.")
2020

21+
@staticmethod
22+
def chunk_forward(
23+
input_chunk,
24+
weight,
25+
target_chunk,
26+
bias=None,
27+
ignore_index=-100,
28+
compute_nll_loss=True,
29+
):
30+
len_chosen_chunk = target_chunk.shape[0] // 2
31+
logits_chunk = input_chunk @ weight.t()
32+
if bias is not None:
33+
logits_chunk = logits_chunk + bias
34+
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
35+
36+
chosen_nll_loss = 0.0
37+
if compute_nll_loss:
38+
chosen_nll_loss = F.nll_loss(
39+
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
40+
target_chunk[:len_chosen_chunk].view(-1),
41+
reduction="sum",
42+
ignore_index=ignore_index,
43+
)
44+
45+
loss_mask = target_chunk != ignore_index
46+
label_chunk = torch.where(loss_mask, target_chunk, 0)
47+
48+
per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
49+
-1
50+
)
51+
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
52+
53+
chosen_logps = average_log_prob[:len_chosen_chunk]
54+
rejected_logps = average_log_prob[len_chosen_chunk:]
55+
return chosen_logps, rejected_logps, chosen_nll_loss
56+
2157
@staticmethod
2258
def forward(
2359
ctx,
@@ -32,6 +68,9 @@ def forward(
3268
beta=0.1,
3369
compute_nll_loss=True,
3470
compiled=True,
71+
use_ref_model=False,
72+
ref_weight=None,
73+
ref_bias=None,
3574
**loss_kwargs,
3675
):
3776
"""
@@ -49,7 +88,11 @@ def forward(
4988
ignore_index (int): Index to ignore for loss computation.
5089
alpha (float): Weight for the NLL loss.
5190
beta (float): Weight for the odds ratio loss.
91+
compute_nll_loss (bool): Whether to compute NLL loss.
5292
compiled (bool): Whether to use torch compile for chunk accumulation.
93+
use_ref_model (bool): Whether to use a reference model for the alignment loss.
94+
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
95+
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
5396
loss_kwargs (dict): Other possible arguments that a loss function might need
5497
"""
5598
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -61,7 +104,6 @@ def forward(
61104
grad_bias = torch.zeros_like(bias) if bias is not None else None
62105
loss_acc = torch.zeros((), device=_input.device)
63106

64-
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
65107
loss_func_to_call = partial(
66108
LigerFusedLinearPreferenceBase._compute_loss,
67109
preference_loss_fn=loss_fn,
@@ -70,6 +112,9 @@ def forward(
70112
beta=beta,
71113
compute_nll_loss=compute_nll_loss,
72114
full_target=target,
115+
use_ref_model=use_ref_model,
116+
ref_weight=ref_weight,
117+
ref_bias=ref_bias,
73118
**loss_kwargs,
74119
)
75120

@@ -101,6 +146,7 @@ def accumulate_chunk(input_chunk, target_chunk):
101146
accumulate_chunk = torch.compile(accumulate_chunk)
102147

103148
len_chosen = target.shape[0] // 2
149+
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
104150
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
105151
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
106152
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
@@ -159,6 +205,9 @@ def _compute_loss(
159205
alpha=1.0,
160206
beta=0.1,
161207
compute_nll_loss=True,
208+
use_ref_model=False,
209+
ref_weight=None,
210+
ref_bias=None,
162211
**loss_kwargs,
163212
):
164213
"""
@@ -173,38 +222,41 @@ def _compute_loss(
173222
ignore_index (int): Index to ignore for loss computation.
174223
alpha (float): Weight for the NLL loss.
175224
beta (float): Weight for the odds ratio loss.
225+
compute_nll_loss (bool): Whether to compute NLL loss.
226+
use_ref_model (bool): Whether to use a reference model for the alignment loss.
227+
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
228+
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
176229
loss_kwargs (dict): Additional arguments for the loss function.
177230
"""
178-
len_chosen_chunk = target_chunk.shape[0] // 2
179-
180-
logits_chunk = input_chunk @ weight.t() # chunk_size x V
181-
if bias is not None:
182-
logits_chunk = logits_chunk + bias
183-
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
184-
185-
chosen_nll_loss = 0.0
186-
if compute_nll_loss:
187-
chosen_nll_loss = F.nll_loss(
188-
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
189-
target_chunk[:len_chosen_chunk].view(-1),
190-
reduction="sum",
231+
chosen_logps, rejected_logps, chosen_nll_loss = (
232+
LigerFusedLinearPreferenceBase.chunk_forward(
233+
input_chunk,
234+
weight,
235+
target_chunk,
236+
bias=bias,
191237
ignore_index=ignore_index,
238+
compute_nll_loss=compute_nll_loss,
192239
)
193-
chosen_nll_loss = (
194-
chosen_nll_loss
195-
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
196-
)
197-
198-
loss_mask = target_chunk != ignore_index
199-
label_chunk = torch.where(loss_mask, target_chunk, 0)
200-
201-
per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
202-
-1
203240
)
204-
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
241+
chosen_nll_loss = (
242+
chosen_nll_loss
243+
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
244+
)
205245

206-
chosen_logps = average_log_prob[:len_chosen_chunk]
207-
rejected_logps = average_log_prob[len_chosen_chunk:]
246+
if use_ref_model:
247+
with torch.no_grad():
248+
ref_chosen_logps, ref_rejected_logps, _ = (
249+
LigerFusedLinearPreferenceBase.chunk_forward(
250+
input_chunk,
251+
ref_weight,
252+
target_chunk,
253+
ref_bias,
254+
ignore_index=ignore_index,
255+
compute_nll_loss=False,
256+
)
257+
)
258+
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
259+
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
208260

209261
alignment_loss = preference_loss_fn(
210262
chosen_logps, rejected_logps, beta=beta, **loss_kwargs

0 commit comments

Comments
 (0)