Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ loss.backward()
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
Expand All @@ -264,6 +265,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
<!-- TODO: verify vocab sizes are accurate -->
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.

### Experimental Kernels

Expand Down
72 changes: 36 additions & 36 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -445,39 +445,39 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908
kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4657920002937317,0.4644480049610138,0.4670400023460388,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,8192,0.9084159731864929,0.9064639806747437,0.9099519848823547,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,16384,9.939423561096191,9.933785438537598,9.945216178894043,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,32768,20.06915283203125,20.05768394470215,20.087200164794922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,65536,38.88547134399414,38.880577087402344,38.89036560058594,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,131072,77.7418212890625,77.7418212890625,77.7418212890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,4096,2.1717119216918945,2.1697471141815186,2.173452854156494,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,8192,4.2592315673828125,4.255411148071289,4.2608771324157715,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,16384,8.363903999328613,8.359071731567383,8.36620807647705,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,32768,16.591264724731445,16.588390350341797,16.595033645629883,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,65536,33.06208038330078,33.06206130981445,33.06536102294922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,131072,66.0923843383789,66.0923843383789,66.0923843383789,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,liger,full,speed,ms,V,vocab size,4096,1.5683839321136475,1.4662528038024902,1.7244799137115479,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,8192,2.0588159561157227,2.055116891860962,2.093465566635132,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,16384,11.944751739501953,11.936684608459473,11.961983680725098,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,32768,24.27791976928711,24.254375457763672,24.299558639526367,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,65536,47.206687927246094,47.17191696166992,47.241458892822266,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,131072,94.15420532226562,94.15420532226562,94.15420532226562,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,torch,full,speed,ms,V,vocab size,4096,4.875328063964844,4.873446464538574,4.878073692321777,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,8192,9.582816123962402,9.57910442352295,9.58505630493164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,16384,18.931264877319336,18.92802619934082,18.934911727905273,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,32768,38.07579040527344,38.07549285888672,38.076087951660156,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,65536,75.97628784179688,75.97628784179688,75.97628784179688,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,131072,151.8501739501953,151.8501739501953,151.8501739501953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
24 changes: 16 additions & 8 deletions benchmark/scripts/benchmark_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@
from liger_kernel.transformers.jsd import LigerJSD


class TorchJSD(torch.nn.Module):
def __init__(self):
class TorchJSD(nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
super(TorchJSD, self).__init__()
self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

def forward(self, log_p: torch.tensor, log_q: torch.tensor):
self.beta = beta
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = 0.5 * (torch.exp(log_p) + torch.exp(log_q))
log_m = torch.log(m)
loss = 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
return loss
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
return loss.to(self.dtype)


def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
Expand Down
47 changes: 31 additions & 16 deletions src/liger_kernel/ops/jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _jsd_kernel(
loss_stride,
dX_ptr,
dX_stride,
beta,
n_rows,
n_cols,
BLOCK_SIZE: tl.constexpr,
Expand All @@ -37,20 +38,20 @@ def _jsd_kernel(

Q = tl.exp(X)
P = tl.exp(Y)
M = 0.5 * P + 0.5 * Q
M = beta * P + (1 - beta) * Q
log_M = tl.log(M)

loss = 0.5 * (P * Y + Q * X - 2 * M * log_M)
loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
tl.store(loss_ptr + offsets, loss, mask=mask)

dX = 0.5 * Q * (X - log_M) / n_rows
dX = (1 - beta) * Q * (X - log_M) / n_rows
tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 65536


def jsd_forward(_input, target):
def jsd_forward(_input, target, beta):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
Expand All @@ -67,6 +68,7 @@ def jsd_forward(_input, target):
loss_stride=loss.stride(-2),
dX_ptr=dX,
dX_stride=dX.stride(-2),
beta=beta,
n_rows=n_rows,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
Expand All @@ -77,23 +79,26 @@ def jsd_forward(_input, target):


def jsd_backward(dX, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return dX
else:
return grad_output * dX


class LigerJSDFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the JS Divergence using Triton, as defined by the following formula:

Parameters:
_input (tensor): predict values with shape (BT, V) in logspace
target (tensor): gournd truth values with shape (BT, V) in logspace

Returns:
loss (tensor): JSD
r"""
This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
.. math::
JSD(\beta)(P || Q)
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

.. note::
As all the other losses in PyTorch, this function expects the first argument,
:attr:`_input`, to be the predictions, the output of the student model, in log-space
and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
"""

@staticmethod
Expand All @@ -102,9 +107,18 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
beta: float = 0.5,
) -> torch.Tensor:

loss, dX = jsd_forward(_input, target)
"""
Args:
_input (torch.Tensor): predict values with shape (BT, V) in logspace
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
beta (float): coefficient beta of generalized JSD in the open interval (0, 1)

Returns:
loss (torch.Tensor): generalized JSD
"""
loss, dX = jsd_forward(_input, target, beta)
ctx.save_for_backward(dX)
return loss

Expand All @@ -116,4 +130,5 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return (
dX,
None,
None,
)
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
from liger_kernel.transformers.monkey_patch import ( # noqa: F401
_apply_liger_kernel,
Expand Down
Loading
Loading