Commit 2281b7e

pramodith authored
Refactor LigerFusedLinearPreferenceBase (#381)
## Summary

This PR refactors the `LigerFusedLinearPreferenceBase` class to contain an abstract method, `preference_loss_fn`, corresponding to the calculation of the alignment loss, which must be implemented by all subclasses. It also adds a new function to the class, `_compute_loss`, which is largely the same as the `_compute_orpo_loss` function introduced in #362 but is generalized to compute the NLL/cross-entropy loss and to accept a custom loss function that implements a new alignment loss. Most RLHF/RLAIF/alignment algorithms state their final loss as `NLL + Beta * (Alignment_Loss)`, so moving the NLL logic into the base class reduces repeated code. The `_compute_loss` function also accepts additional keyword arguments (`loss_kwargs`) that are forwarded to the preference loss function.

## Testing Done

- Hardware Type: A100-80G-SXM
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

Co-authored-by: pramodith <[email protected]>
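As a sketch of the extension pattern this enables (the subclass name and loss below are hypothetical, not part of this commit), a new alignment algorithm only needs to override `preference_loss_fn`; the base class supplies the NLL term, chunking, and gradient accumulation:

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearMyAlignmentFunction(LigerFusedLinearPreferenceBase):
    """Hypothetical subclass used only to illustrate the pattern."""

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Illustrative sigmoid-style preference term; like ORPO's odds-ratio
        # loss, the returned value is scaled by beta and summed, and the base
        # class subtracts it from the chosen-sequence NLL.
        return beta * F.logsigmoid(chosen_logps - rejected_logps).sum()
```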
1 parent 6b2fd02 · commit 2281b7e

File tree: 2 files changed (+126, −81 lines)

src/liger_kernel/chunked_loss/fused_linear_preference.py

Lines changed: 101 additions & 2 deletions
```diff
@@ -1,7 +1,23 @@
+from abc import abstractmethod
+from functools import partial
+
 import torch
+from torch.nn import functional as F
 
 
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+    @abstractmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute preference loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
     @staticmethod
     def forward(
         ctx,
@@ -11,6 +27,9 @@ def forward(
         bias=None,
         loss_fn=None,
         chunk_size=1,
+        compute_nll_loss=True,
+        ignore_index=-100,
+        beta=0.1,
         compiled=True,
     ):
         """
@@ -24,6 +43,9 @@ def forward(
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -36,21 +58,33 @@ def forward(
         loss_acc = torch.zeros((), device=_input.device)
 
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+        loss_func_to_call = partial(
+            LigerFusedLinearPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            full_target=target,
+        )
 
         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
                     input_chunk, weight, target_chunk, bias
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
                     input_chunk, weight, target_chunk
                 )
                 grad_weight.add_(chunk_grad_weight)
@@ -105,3 +139,68 @@ def backward(ctx, grad_output):
         grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        len_chosen_chunk = target_chunk.shape[0] // 2
+
+        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+            chosen_nll_loss = (
+                chosen_nll_loss
+                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        alignment_loss = preference_loss_fn(
+            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        )
+        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+
+        loss = chosen_nll_loss - alignment_loss
+        return loss, (alignment_loss, chosen_logps, rejected_logps)
```
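For readers unfamiliar with the mechanism, here is a self-contained toy of how the base class obtains per-chunk gradients (the shapes and the toy loss are made up; only the `torch.func.grad_and_value` call pattern mirrors the code above):

```python
import torch
import torch.nn.functional as F


def chunk_loss(input_chunk, weight, target_chunk):
    # Toy stand-in for _compute_loss: must return (loss, aux) for has_aux=True.
    logits = input_chunk @ weight.t()
    log_probs = F.log_softmax(logits.float(), dim=-1)
    loss = F.nll_loss(log_probs, target_chunk, reduction="sum")
    return loss, log_probs.detach()


input_chunk = torch.randn(4, 8)            # (2 * chunk_size, hidden_size)
weight = torch.randn(16, 8)                # (vocab_size, hidden_size)
target_chunk = torch.randint(0, 16, (4,))

# Differentiate w.r.t. argnums (0, 1) = (input_chunk, weight); the grads come
# back alongside the loss value and the auxiliary output, exactly as
# destructured in accumulate_chunk above.
(grad_input, grad_weight), (loss, log_probs) = torch.func.grad_and_value(
    chunk_loss, argnums=(0, 1), has_aux=True
)(input_chunk, weight, target_chunk)
```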

src/liger_kernel/chunked_loss/orpo_loss.py

Lines changed: 25 additions & 79 deletions
```diff
@@ -1,5 +1,3 @@
-from functools import partial
-
 import torch
 import torch.nn.functional as F
 
@@ -8,79 +6,24 @@
 )
 
 
-def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
-    """
-    Compute odds-ratio loss.
-    Args:
-        chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-        rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-        beta (float): Weight for the odds ratio loss.
-    """
-    log_odds = (chosen_logps - rejected_logps) - (
-        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
-    )
-    ratio = F.logsigmoid(log_odds)
-    return beta * ratio.sum()
-
-
-def _compute_orpo_loss(
-    input_chunk,
-    weight,
-    target_chunk,
-    bias=None,
-    full_target=None,
-    ignore_index=-100,
-    beta=0.1,
-    compute_nll_loss=True,
-):
-    """
-    Compute ORPO loss for a chunk of input and target.
-    Args:
-        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-        ignore_index (int): Index to ignore for loss computation.
-        beta (float): Weight for the odds ratio loss.
-    """
-    len_chosen_chunk = target_chunk.shape[0] // 2
-
-    logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-    if bias is not None:
-        logits_chunk = logits_chunk + bias
-    log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
-    chosen_nll_loss = 0.0
-    if compute_nll_loss:
-        chosen_nll_loss = F.nll_loss(
-            log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-            target_chunk[:len_chosen_chunk].view(-1),
-            reduction="sum",
-            ignore_index=ignore_index,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps))
+            - torch.log1p(-torch.exp(rejected_logps))
         )
+        ratio = F.logsigmoid(log_odds)
+        return beta * ratio.sum()
 
-    loss_mask = target_chunk != ignore_index
-    label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-    per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-    average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-    chosen_logps = average_log_prob[:len_chosen_chunk]
-    rejected_logps = average_log_prob[len_chosen_chunk:]
-
-    or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
-    or_loss = or_loss / (full_target.shape[0] // 2)
-
-    loss = chosen_nll_loss - or_loss
-    return loss, (or_loss, chosen_logps, rejected_logps)
-
-
-class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def forward(
         ctx,
@@ -98,15 +41,18 @@ def forward(
         Handles both the forward and backward pass of the final linear layer with ORPO loss.
         Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
         """
-        orpo_loss_fn = partial(
-            _compute_orpo_loss,
-            full_target=target,
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
             ignore_index=ignore_index,
             beta=beta,
-            compute_nll_loss=compute_nll_loss,
-        )
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx, _input, weight, target, bias, loss_fn=orpo_loss_fn
+            compiled=compiled,
         )
 
     @staticmethod
```
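And a minimal usage sketch, assuming illustrative shapes; the positional argument order after `target` (bias, ignore_index, beta, compute_nll_loss, compiled) is inferred from the parameters visible in this diff and should be verified against the full `forward` signature:

```python
import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

B, T, H, V = 2, 16, 64, 128  # illustrative sizes

# Chosen and rejected sequences are stacked along the batch dimension,
# so the function sees an effective batch of 2 * B rows.
_input = torch.randn(2 * B, T, H, requires_grad=True)
weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (2 * B, T))

loss = LigerFusedLinearORPOFunction.apply(
    _input, weight, target,
    None,   # bias
    -100,   # ignore_index (assumed position)
    0.1,    # beta (assumed position)
    True,   # compute_nll_loss (assumed position)
    True,   # compiled (assumed position)
)
loss.backward()
```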
