
Commit 96ba54e

Add descriptions of LigerJSD and modify the reference model with lerp
1 parent a5d0352 commit 96ba54e

4 files changed: 53 additions (+), 14 deletions (−)

benchmark/scripts/benchmark_jsd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ def forward(
     ):
         log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
         log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
-        m = self.beta * torch.exp(log_p) + (1 - self.beta) * torch.exp(log_q)
+        m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
         loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
             torch.log(m), log_q
         )
```
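
For context (not part of the commit): `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, so the `end` argument is the one scaled by `weight`, and at `weight = 0.5` the order of the first two arguments makes no difference. A minimal sketch of that identity with illustrative tensors:

```python
import torch

beta = 0.5
log_p = torch.randn(4, 16).log_softmax(dim=-1)
log_q = torch.randn(4, 16).log_softmax(dim=-1)

# lerp(start, end, weight) = start + weight * (end - start)
#                          = (1 - weight) * start + weight * end
m_lerp = torch.lerp(torch.exp(log_p), torch.exp(log_q), beta)
m_sum = (1 - beta) * torch.exp(log_p) + beta * torch.exp(log_q)
assert torch.allclose(m_lerp, m_sum)
```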

src/liger_kernel/ops/jsd.py

Lines changed: 21 additions & 11 deletions
```diff
@@ -87,16 +87,18 @@ def jsd_backward(dX, grad_output):
 
 
 class LigerJSDFunction(torch.autograd.Function):
-    """
-    Class implementing the forward and backward pass for the JS Divergence using Triton, as defined by the following formula:
-
-    Parameters:
-    _input (tensor): predict values with shape (BT, V) in logspace
-    target (tensor): ground truth values with shape (BT, V) in logspace
-    beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
-
-    Returns:
-    loss (tensor): JSD
+    r"""
+    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As all the other losses in PyTorch, this function expects the first argument,
+        :attr:`_input`, to be the predictions, the output of the student model, in log-space
+        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
     """
 
     @staticmethod
@@ -107,7 +109,15 @@ def forward(
         target: torch.Tensor,
         beta: float = 0.5,
     ) -> torch.Tensor:
-
+        """
+        Args:
+            _input (torch.Tensor): predict values with shape (BT, V) in logspace
+            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
         loss, dX = jsd_forward(_input, target, beta)
         ctx.save_for_backward(dX)
         return loss
```
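
To make the docstring's formula concrete, here is a minimal pure-PyTorch sketch of the generalized JSD. The function name `generalized_jsd_reference` and the `batchmean` reduction are illustrative assumptions; this is not the repo's Triton kernel:

```python
import torch
import torch.nn.functional as F

def generalized_jsd_reference(
    _input: torch.Tensor,  # student log-probabilities, shape (BT, V)
    target: torch.Tensor,  # teacher log-probabilities, shape (BT, V)
    beta: float = 0.5,
) -> torch.Tensor:
    log_q, log_p = _input, target
    # Mixture distribution M = beta * P + (1 - beta) * Q
    log_m = torch.log(beta * torch.exp(log_p) + (1 - beta) * torch.exp(log_q))
    # F.kl_div(input, target, log_target=True) computes KL(target || input),
    # with both arguments given in log-space.
    kl_pm = F.kl_div(log_m, log_p, reduction="batchmean", log_target=True)
    kl_qm = F.kl_div(log_m, log_q, reduction="batchmean", log_target=True)
    return beta * kl_pm + (1 - beta) * kl_qm
```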

src/liger_kernel/transformers/jsd.py

Lines changed: 29 additions & 0 deletions
````diff
@@ -4,6 +4,35 @@
 
 
 class LigerJSD(nn.Module):
+    r"""The generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+    .. note::
+        As all the other losses in PyTorch, this function expects the first argument,
+        :attr:`log_q`, to be the predictions, the output of the student model in log-space,
+        and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+
+    Args:
+        beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Target: :math:`(*)`, same shape as the input.
+        - Output: a scalar.
+
+    Examples:
+    ```python
+    >>> jsd = LigerJSD(beta=0.1)
+    >>> # input should be a distribution in the log space
+    >>> input = torch.randn(3, 5, requires_grad=True).log_softmax(dim=-1)
+    >>> target = torch.randn(3, 5, requires_grad=True).log_softmax(dim=-1)
+    >>> output = jsd(input, target)
+    ```
+    """
+
     def __init__(self, beta=0.5):
         super().__init__()
         assert (
````
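
A usage sketch for a distillation-style setting, mirroring the argument order described in the note above (student prediction first, teacher second, both in log-space); the import path and tensor shapes are assumptions for illustration:

```python
import torch
from liger_kernel.transformers import LigerJSD  # import path assumed

jsd = LigerJSD(beta=0.5)

# Hypothetical student/teacher logits for a batch of 4 examples over 128 classes.
student_logits = torch.randn(4, 128, requires_grad=True)
teacher_logits = torch.randn(4, 128)

log_q = student_logits.log_softmax(dim=-1)  # student predictions (first argument)
log_p = teacher_logits.log_softmax(dim=-1)  # teacher observations (second argument)

loss = jsd(log_q, log_p)
loss.backward()  # gradients flow only to student_logits
```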

test/transformers/test_jsd.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,7 +24,7 @@ def forward(
     ):
         log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
         log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
-        m = self.beta * torch.exp(log_p) + (1 - self.beta) * torch.exp(log_q)
+        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
         loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
             torch.log(m), log_q
         )
@@ -36,7 +36,7 @@ def forward(
     [
         (2, 4096, 32000),  # llama2, mistral
         (2, 4096, 32000),  # llama2, mistral
-        # # weird shape
+        # weird shape
         (41, 401, 1271),
         pytest.param(
             1,
```
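
As an aside (illustrative, not part of the test suite): with the argument order used in this reference model, `torch.lerp` reproduces the original weighted sum exactly for any `beta`:

```python
import torch

beta = 0.3
log_p = torch.randn(8, 32).log_softmax(dim=-1)
log_q = torch.randn(8, 32).log_softmax(dim=-1)

# lerp(exp(log_q), exp(log_p), beta) = (1 - beta) * exp(log_q) + beta * exp(log_p)
m_lerp = torch.lerp(torch.exp(log_q), torch.exp(log_p), beta)
m_sum = beta * torch.exp(log_p) + (1 - beta) * torch.exp(log_q)
assert torch.allclose(m_lerp, m_sum)
```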
