
Commit 6817c2d

Tcc0403 and lancerts authored
Add beta support for jsd (#290)
## Summary

Resolve #278.

## Details

### Forward:

```math
\begin{align}
JSD(X, Y, \beta) &= JSD_{\beta}(P \Vert Q)\\
&= \beta\ KL(P \Vert \beta P + (1-\beta)Q) + (1-\beta)\ KL(Q \Vert \beta P + (1-\beta)Q)\\
&= \sum \beta\, P Y + (1-\beta)\, Q X - M \log M
\end{align}
```

where $X = \log Q$, $Y = \log P$, and $M = \beta P + (1-\beta)Q$.

### Gradients:

```math
\frac{\partial}{\partial X_i} JSD(X, Y, \beta) = (1-\beta)\, Q_i (X_i - \log M_i)
```

## Testing Done

![jsd_memory](https://github.com/user-attachments/assets/a26e1a64-df4b-49fe-8564-01a6757cb76a)
![jsd_speed](https://github.com/user-attachments/assets/6f631bdb-5abf-44ed-875b-2596f3a30b8b)

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
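To make the formulas above concrete, here is a minimal eager-PyTorch sketch (not the fused Triton kernel from this commit; the function name `generalized_jsd` and the shapes are purely illustrative) that evaluates the loss and the stated gradient and cross-checks the gradient against autograd:

```python
import torch


def generalized_jsd(log_q: torch.Tensor, log_p: torch.Tensor, beta: float = 0.5):
    # X = log_q (student log-probs), Y = log_p (teacher log-probs)
    p, q = log_p.exp(), log_q.exp()
    m = beta * p + (1 - beta) * q
    # loss = sum( beta*P*Y + (1-beta)*Q*X - M*log M )
    loss = (beta * p * log_p + (1 - beta) * q * log_q - m * m.log()).sum()
    # analytic gradient: dLoss/dX_i = (1-beta) * Q_i * (X_i - log M_i)
    grad_x = (1 - beta) * q * (log_q - m.log())
    return loss, grad_x


# Quick autograd check of the gradient formula (float64 for tight tolerances).
log_q = torch.randn(4, 8, dtype=torch.float64).log_softmax(-1).requires_grad_(True)
log_p = torch.randn(4, 8, dtype=torch.float64).log_softmax(-1)
loss, grad_x = generalized_jsd(log_q, log_p, beta=0.3)
loss.backward()
assert torch.allclose(log_q.grad, grad_x)
```

The fused kernel additionally divides `dX` by the number of rows because its reduction averages the per-row losses; the plain sum above keeps the check self-contained.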
1 parent 60640e1 · commit 6817c2d

File tree: 8 files changed, +230 -75 lines changed


README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -250,6 +250,7 @@ loss.backward()
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
 | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
+| JSD | `liger_kernel.transformers.LigerJSD` |
 
 - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
 - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -264,6 +265,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 <!-- TODO: verify vocab sizes are accurate -->
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
+- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
 
 ### Experimental Kernels
 
```
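For orientation, a hedged usage sketch of the new module listed in the README table above (the `beta` keyword mirrors `LigerJSDFunction`; any other constructor arguments are not assumed here, and a CUDA device is required for the Triton kernel):

```python
import torch
from liger_kernel.transformers import LigerJSD

jsd = LigerJSD(beta=0.5)  # 0 < beta < 1; 0.5 recovers the standard JSD

# BT = batch * sequence length, V = vocab size (illustrative sizes)
student_logits = torch.randn(8 * 128, 32000, device="cuda", requires_grad=True)
teacher_logits = torch.randn(8 * 128, 32000, device="cuda")

# Both arguments are expected in log-space: student (input) first, teacher (target) second.
loss = jsd(
    torch.log_softmax(student_logits, dim=-1),
    torch.log_softmax(teacher_logits, dim=-1),
)
loss.backward()
```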
benchmark/data/all_benchmark_data.csv

Lines changed: 36 additions & 36 deletions
```diff
@@ -445,39 +445,39 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908
 kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
 kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
 kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
-jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,4096,0.4657920002937317,0.4644480049610138,0.4670400023460388,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,8192,0.9084159731864929,0.9064639806747437,0.9099519848823547,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,16384,9.939423561096191,9.933785438537598,9.945216178894043,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,32768,20.06915283203125,20.05768394470215,20.087200164794922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,65536,38.88547134399414,38.880577087402344,38.89036560058594,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,liger,forward,speed,ms,V,vocab size,131072,77.7418212890625,77.7418212890625,77.7418212890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,4096,2.1717119216918945,2.1697471141815186,2.173452854156494,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,8192,4.2592315673828125,4.255411148071289,4.2608771324157715,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,16384,8.363903999328613,8.359071731567383,8.36620807647705,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,32768,16.591264724731445,16.588390350341797,16.595033645629883,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,65536,33.06208038330078,33.06206130981445,33.06536102294922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,torch,forward,speed,ms,V,vocab size,131072,66.0923843383789,66.0923843383789,66.0923843383789,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,4096,1.5683839321136475,1.4662528038024902,1.7244799137115479,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,8192,2.0588159561157227,2.055116891860962,2.093465566635132,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,16384,11.944751739501953,11.936684608459473,11.961983680725098,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,32768,24.27791976928711,24.254375457763672,24.299558639526367,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,65536,47.206687927246094,47.17191696166992,47.241458892822266,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,liger,full,speed,ms,V,vocab size,131072,94.15420532226562,94.15420532226562,94.15420532226562,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,4096,4.875328063964844,4.873446464538574,4.878073692321777,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,8192,9.582816123962402,9.57910442352295,9.58505630493164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,16384,18.931264877319336,18.92802619934082,18.934911727905273,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,32768,38.07579040527344,38.07549285888672,38.076087951660156,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,65536,75.97628784179688,75.97628784179688,75.97628784179688,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
-jsd,torch,full,speed,ms,V,vocab size,131072,151.8501739501953,151.8501739501953,151.8501739501953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
+jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
+jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
```

benchmark/scripts/benchmark_jsd.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -13,17 +13,25 @@
 from liger_kernel.transformers.jsd import LigerJSD
 
 
-class TorchJSD(torch.nn.Module):
-    def __init__(self):
+class TorchJSD(nn.Module):
+    def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
         super(TorchJSD, self).__init__()
         self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
-
-    def forward(self, log_p: torch.tensor, log_q: torch.tensor):
+        self.beta = beta
+        self.dtype = dtype
+
+    def forward(
+        self,
+        log_q: torch.tensor,  # input
+        log_p: torch.tensor,  # target
+    ):
+        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
         log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
-        m = 0.5 * (torch.exp(log_p) + torch.exp(log_q))
-        log_m = torch.log(m)
-        loss = 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
-        return loss
+        m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
+        loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
+            torch.log(m), log_q
+        )
+        return loss.to(self.dtype)
 
 
 def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
```
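Outside the speed/memory benchmark itself, a small spot-check (a sketch, not part of this commit) comparing the eager `TorchJSD` baseline above with the fused `LigerJSD`; at the default `beta = 0.5` the two reductions should agree up to floating-point tolerance:

```python
import torch
from liger_kernel.transformers.jsd import LigerJSD

# Assumes the TorchJSD class from the benchmark script above is in scope,
# and that a CUDA device is available for the Triton kernel.
torch_jsd = TorchJSD(beta=0.5)
liger_jsd = LigerJSD(beta=0.5)

log_q = torch.randn(32, 4096, device="cuda").log_softmax(-1)  # student log-probs (input)
log_p = torch.randn(32, 4096, device="cuda").log_softmax(-1)  # teacher log-probs (target)

# Both take the student log-probs first and the teacher log-probs second.
ref = torch_jsd(log_q, log_p)
out = liger_jsd(log_q, log_p)
assert torch.allclose(ref, out, atol=1e-5, rtol=1e-5)
```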

src/liger_kernel/ops/jsd.py

Lines changed: 31 additions & 16 deletions
```diff
@@ -15,6 +15,7 @@ def _jsd_kernel(
     loss_stride,
     dX_ptr,
     dX_stride,
+    beta,
     n_rows,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
@@ -37,20 +38,20 @@ def _jsd_kernel(
 
     Q = tl.exp(X)
     P = tl.exp(Y)
-    M = 0.5 * P + 0.5 * Q
+    M = beta * P + (1 - beta) * Q
     log_M = tl.log(M)
 
-    loss = 0.5 * (P * Y + Q * X - 2 * M * log_M)
+    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
     tl.store(loss_ptr + offsets, loss, mask=mask)
 
-    dX = 0.5 * Q * (X - log_M) / n_rows
+    dX = (1 - beta) * Q * (X - log_M) / n_rows
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
 MAX_FUSED_SIZE = 65536
 
 
-def jsd_forward(_input, target):
+def jsd_forward(_input, target, beta):
     BT, V = _input.shape
     n_rows = BT
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
@@ -67,6 +68,7 @@ def jsd_forward(_input, target):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
+        beta=beta,
         n_rows=n_rows,
         n_cols=V,
         BLOCK_SIZE=BLOCK_SIZE,
@@ -77,23 +79,26 @@ def jsd_forward(_input, target):
 
 
 def jsd_backward(dX, grad_output):
-    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         return dX
     else:
         return grad_output * dX
 
 
 class LigerJSDFunction(torch.autograd.Function):
-    """
-    Class implementing the forward and backward pass for the JS Divergence using Triton, as defined by the following formula:
-
-    Parameters:
-    _input (tensor): predict values with shape (BT, V) in logspace
-    target (tensor): gournd truth values with shape (BT, V) in logspace
-
-    Returns:
-    loss (tensor): JSD
+    r"""
+    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As all the other losses in PyTorch, this function expects the first argument,
+        :attr:`_input`, to be the predictions, the output of the student model, in log-space
+        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
     """
 
     @staticmethod
@@ -102,9 +107,18 @@ def forward(
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        beta: float = 0.5,
     ) -> torch.Tensor:
-
-        loss, dX = jsd_forward(_input, target)
+        """
+        Args:
+            _input (torch.Tensor): predict values with shape (BT, V) in logspace
+            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
+            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+
+        Returns:
+            loss (torch.Tensor): generalized JSD
+        """
+        loss, dX = jsd_forward(_input, target, beta)
         ctx.save_for_backward(dX)
         return loss
 
@@ -116,4 +130,5 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         return (
             dX,
             None,
+            None,
         )
```
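For completeness, a sketch of driving the autograd function above directly (`apply` is the standard `torch.autograd.Function` entry point; the shapes and `beta=0.7` are arbitrary, and a CUDA device is required for the Triton kernel):

```python
import torch
from liger_kernel.ops.jsd import LigerJSDFunction

log_q = torch.randn(16, 4096, device="cuda").log_softmax(-1).requires_grad_(True)  # student
log_p = torch.randn(16, 4096, device="cuda").log_softmax(-1)  # teacher

loss = LigerJSDFunction.apply(log_q, log_p, 0.7)
# The forward pass already produced dX; backward only rescales it by grad_output
# (and skips the multiply entirely when grad_output is exactly 1.0).
loss.backward()
print(log_q.grad.shape)  # torch.Size([16, 4096])
```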

src/liger_kernel/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@
     LigerFusedLinearCrossEntropyLoss,
 )
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
+from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     _apply_liger_kernel,
```
