commit ae7e13ba1eaf58e5066b5cd60dfddf4f66f3cfed
Merge: ede50df 280cb81
Author: Wizyoung <[email protected]>
Date: Thu Nov 7 15:58:13 2024 +0800
Merge branch 'linkedin:main' into main
commit 280cb8139511753ab3a16f286ebffe694ddd1970
Author: Haoyi Wu <[email protected]>
Date: Thu Nov 7 13:45:16 2024 +0800
Improve compatibility to access the base models (#340)
## Summary
This PR resolves #337 by improving compatibility: the base models are now
accessed through the `base_model_prefix` attribute.
## Details
One thing to mention: `mllama` seems to be an outlier. It has both a text
model and a vision model, so it is impossible to access them through a
single attribute. Meanwhile, `base_model_prefix` seems to have different
semantics for the `mllama` model classes, so I left the code for `mllama`
unchanged.
For the other models, I looked into the `transformers` library and manually
checked correctness.
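For illustration, a minimal sketch of the access pattern (the checkpoint
name is just an example):
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
# `base_model_prefix` names the attribute that holds the base model
# (e.g. "model" for Llama-style causal LM classes), so patches can reach
# it without hard-coding one attribute name per architecture:
base_model = getattr(model, model.base_model_prefix)
```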
## Testing Done
The changes passed `test/transformers/test_monkey_patch.py` by running
`pytest`.
- Hardware Type: RTX 3090
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
Co-authored-by: Byron Hsu <[email protected]>
commit ab5e88be1950aba248555e5e01907de04329e4dc
Author: Tcc0403 <[email protected]>
Date: Thu Nov 7 13:29:08 2024 +0800
Support Z Loss in CE (#239)
## Summary
This PR aims to resolve #197
Implemented z loss in LigerCrossEntropy.
Note: `lse_square_scale` is not exposed at FLCE yet; there are issues
passing the tests.
## Details
### For loss:
```math
\begin{align}
L_{total} &= L_{ce} + z\_loss \\
z\_loss &= lse\_square\_scale \cdot lse^2 \\
lse &= \log \sum e^{X_i}
\end{align}
```
We can use $m = \max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from the
online softmax algorithm, to calculate $lse$ directly.
```math
\begin{align}
lse &= \log \sum e^{X_i} \\
&= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m \\
&= \log\left(e^m \sum e^{X_i - m}\right) = m + \log d
\end{align}
```
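As a quick sanity check (plain PyTorch, not the Triton kernel), the
identity can be verified against `torch.logsumexp`:
```python
import torch

X = torch.randn(8, 32000)
m = X.max(dim=-1, keepdim=True).values          # online-softmax running max
d = torch.exp(X - m).sum(dim=-1, keepdim=True)  # online-softmax running sum
lse = (m + torch.log(d)).squeeze(-1)
assert torch.allclose(lse, torch.logsumexp(X, dim=-1), atol=1e-5)
```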
### For gradients:
First, we calculate the derivative of lse
```math
\begin{align}
\frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum_j e^{x_j}\right) \\
&= \frac{1}{\sum_j e^{x_j}} \cdot \frac{\partial}{\partial x_i} \sum_j e^{x_j} \\
&= \frac{e^{x_i}}{\sum_j e^{x_j}} = softmax(x_i).
\end{align}
```
Then we can obtain the derivative of z_loss by chain rule.
```math
\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right) = 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i),
```
and we have the derivative of cross entropy loss with label smoothing
```math
\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K} = \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
```
where $\epsilon$ is `label_smoothing` and $K$ is the total number of
classes.
Thus, the derivative of total loss is
```math
\begin{align}
\frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
&= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
&=\begin{cases} (1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K}, & i \neq y\\
(1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
```
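The closed form can be checked against autograd with a small eager-mode
reference (no label smoothing, `sum` reduction; a sketch, not the kernel):
```python
import torch
import torch.nn.functional as F

lse_square_scale = 1e-4
X = torch.randn(4, 10, requires_grad=True)
y = torch.randint(0, 10, (4,))

lse = torch.logsumexp(X, dim=-1)
loss = F.cross_entropy(X, y, reduction="sum") + (lse_square_scale * lse**2).sum()
loss.backward()

# (1 + 2 * lse_square_scale * lse) * softmax(x_i) - delta_{i,y}  (epsilon = 0)
probs = torch.softmax(X.detach(), dim=-1)
grad = (1 + 2 * lse_square_scale * lse.detach().unsqueeze(-1)) * probs
grad[torch.arange(4), y] -= 1.0
assert torch.allclose(X.grad, grad, atol=1e-5)
```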
### Reference
[PaLM: Scaling Language Modeling with
Pathways](https://www.jmlr.org/papers/v24/22-1144.html)
[Chameleon: Mixed-Modal Early-Fusion Foundation
Models](https://arxiv.org/abs/2405.09818)
## Testing Done
[benchmark
gist](https://gist.github.com/Tcc0403/b9120282334196f66b5169d9f52bccaa)
Negligible error in the speed benchmark. The benchmark was done on my own
machine, so the numbers are probably not very accurate.
```
liger ce: 66.123ms
Peak mem: 8.66200832
liger ce with zloss: 65.991ms
Peak mem: 8.66200832
liger ce with zloss with return zloss: 65.951ms
Peak mem: 8.662073856
```
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
commit 85d34efbd423cd97d3e97525af419193fbb07354
Author: Pramodith Ballapuram <[email protected]>
Date: Wed Nov 6 17:44:54 2024 +0000
BUG: Fix bug in layer norm tests. (#359)
## Summary
This PR fixes a bug in a test case for layer norm, where the assertion on
the gradient of x incorrectly compared the gradient against itself, meaning
the assertion would always succeed.
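For illustration (a hypothetical setup mirroring the test, not the test
code itself), the broken pattern versus the fix:
```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
x1 = x.clone().requires_grad_(True)  # fed to the Liger layer norm
x2 = x.clone().requires_grad_(True)  # fed to the reference layer norm
F.layer_norm(x1, (8,)).sum().backward()
F.layer_norm(x2, (8,)).sum().backward()

# Broken: trivially true, a tensor always equals itself.
assert torch.allclose(x1.grad, x1.grad)
# Fixed: actually compares the two implementations' gradients.
assert torch.allclose(x1.grad, x2.grad)
```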
## Testing Done
Tested on A100-80G-SXM4.
- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence
commit c131f0423ccef96e71a13d58bda168f5904bfa89
Author: Byron Hsu <[email protected]>
Date: Tue Nov 5 16:50:38 2024 -0800
Update ci.yml
commit 985e6c74b61656061f28be74434a6de2de3aabfd
Author: Byron Hsu <[email protected]>
Date: Tue Nov 5 16:13:49 2024 -0800
Update ci.yml
commit a8c085488f3c47b86b2d560a1225bc27ec59c68d
Author: Byron Hsu <[email protected]>
Date: Tue Nov 5 15:58:11 2024 -0800
fixing ci
commit e985195bec82ea9d89b9d20a758356eee1650dc1
Author: Byron Hsu <[email protected]>
Date: Tue Nov 5 14:10:52 2024 -0800
Update pyproject.toml
commit 98d77e077d7bf8335a4a7748067ea8fc3633e3ef
Author: Byron Hsu <[email protected]>
Date: Tue Nov 5 14:05:27 2024 -0800
broadcast grad acc fix to all models (#354)
## Summary
follow up for https://github.com/linkedin/Liger-Kernel/pull/339
However, a few issues were identified:
1. Reverting the patching causes FLCE to not take effect (comment out
revert patching for now, and only test float32).
2. Qwen2-VL FLCE is broken; we should fix it later.
3. We should provide a real "on-instance" patch that does not use any
monkey patching; currently the on-instance patch still relies on monkey
patching.
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit ef3f55dcd06b4fca95a5b75c9fe51ef1b7b7bfef
Author: Byron Hsu <[email protected]>
Date: Mon Nov 4 17:04:47 2024 -0800
merge two tests into one (#349)
## Summary
remove the launching overhead of the 2nd container
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit b09fb65a37a045aa64e92b4d493897ba1c462ce8
Author: Byron Hsu <[email protected]>
Date: Mon Nov 4 16:40:52 2024 -0800
Trim conv test (#348)
## Summary
Remove the non-FLCE convergence tests, since most users are using FLCE.
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit fbcb52d615f46f54ce865cec028ce5c64a205a2a
Author: ByronHsu <[email protected]>
Date: Mon Nov 4 22:54:09 2024 +0000
Move dependent license to a folder
commit a2dfa3cb2f7b6f0e23a65ad76b38a6b567404a2c
Author: Byron Hsu <[email protected]>
Date: Mon Nov 4 14:04:40 2024 -0800
Aggressively trim test bloat (#346)
## Summary
1. Disable the tests for experimental kernels.
2. Reduce tensor sizes where tests take too long.
3. Remove redundant tests that test the same thing.
Make sure unit test time is < 5 mins.
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit e68b291f11d2f1ab22c5db9b1038021ee1821a0e
Author: Byron Hsu <[email protected]>
Date: Mon Nov 4 13:14:38 2024 -0800
avoid duplicate ci (#345)
## Summary
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit c34843c45eb8c3501d54f506fa359401e06d0166
Author: Byron Hsu <[email protected]>
Date: Mon Nov 4 13:08:19 2024 -0800
set up modal ci (#344)
## Summary
follow https://github.com/modal-labs/ci-on-modal
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit ac7b38a2fdd3368b648d5ee02f6c0fb8661d8005
Author: TJian <[email protected]>
Date: Sun Nov 3 01:07:39 2024 +0800
[AMD] [ROCm] Pick `num_warps` based on platform (#326)
## Summary
This PR enables the kernels to run on AMD GPUs through initial changes to
`num_warps` (a sketch of the approach follows below).
This change was proposed by @Edenzzzz and @DocShotgun in issue
https://github.com/linkedin/Liger-Kernel/issues/266
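A minimal sketch of the idea (illustrative thresholds, not the exact values
in this PR): detect ROCm and choose fewer warps, since AMD wavefronts are
64 lanes wide versus NVIDIA's 32-lane warps:
```python
import torch

def is_hip() -> bool:
    return torch.version.hip is not None

def pick_num_warps(block_size: int) -> int:
    # Bigger blocks want more warps, but a ROCm "warp" (wavefront)
    # already covers 64 lanes, so roughly half as many are needed.
    if block_size >= 32768:
        return 16 if is_hip() else 32
    if block_size >= 8192:
        return 8 if is_hip() else 16
    if block_size >= 2048:
        return 4 if is_hip() else 8
    return 4
```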
## Details
I have updated the `transformers` version requirement from `4.44.0` to
`4.46.0`, and all unit tests passed on A100 and MI300X.
## Testing Done
- Hardware Type: AMD Instinct MI300X
- [x] run `make test` to ensure correctness
- Some tests failed due to numerical precision issues. They passed after
relaxing the tolerance by one order of magnitude (following the advice in
the Liger-Kernel technical report https://arxiv.org/pdf/2410.10989,
**Footnote 12:** _Note that in practice, the tolerance may need further
relaxation in some cases by one or two orders of magnitude, even for
exact kernels. We use convergence tests to ensure exactness in cases
where the tolerance for correctness needs to be loose._)
- The tests with relaxed tolerances involve `kl_div` and `jsd` in
`float32`.
- The relaxed conditions are described by the following code snippet:
```python
_DTYPE_PARAMS = (
"dtype, atol, rtol",
[
pytest.param(
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(torch.float32, 1e-8 if not is_hip() else 1e-7, 1e-6),
(torch.float16, 1e-3, 1e-3),
],
)
```
- To pass the tests, Triton must not be installed from source; it must be
installed from PyPI (`pip install triton==3.0.0`). This issue will be
tracked with an issue at Triton:
https://github.com/triton-lang/triton/issues/5013
- ~~Something is weird as well, if I just run the failed test
`test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`,
the test passed. By running `pytest
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`.
However it will failed if there are other tests running before this
test.~~
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
<details>
<summary> <s>Failure Test Logs (Click to expand/collapse) </s>
</summary>
```bash
============================================================= FAILURES =============================================================
________________________ test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] _________________________
B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0, dtype = torch.float32, atol = 1e-08, rtol = 1e-06
@pytest.mark.parametrize(
"B, T, V, ignore_index",
[
(2, 4096, 32000, -100), # llama2, mistral
(2, 4096, 32000, 2), # llama2, mistral
(1, 4096, 128256, -300), # llama3
# weird shapes
(3, 423, 32000, -123),
],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_ignore_index(
B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
> _test_correctness_with_ignore_index_once(
liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
)
test/transformers/test_cross_entropy.py:302:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
target_ce = LigerCrossEntropyLoss(), B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0
dtype = torch.float32, atol = 1e-08, rtol = 1e-06
def _test_correctness_with_ignore_index_once(
target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
_tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)
target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
).item() # Random number of elements to set to ignore_index
indices_to_assign = torch.randperm(B * T)[
:num_elements_to_assign
] # Randomly select indices
target[indices_to_assign] = ignore_index
output = torch_ce(_input, target)
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)
output.backward()
output2.backward()
> assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
E AssertionError: assert False
E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad
E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad
test/transformers/test_cross_entropy.py:61: AssertionError
_________________________________ test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] _________________________________
B = 1, T = 4096, V = 128256, beta = 0.1, dtype = torch.float32, atol = 1e-08, rtol = 1e-06
@pytest.mark.parametrize(*_SHAPE_PARAMS)
@pytest.mark.parametrize(*_DTYPE_PARAMS)
@pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
liger_jsd = LigerJSD(beta=beta)
> _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)
test/transformers/test_jsd.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5
def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
"""
Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
Parameters:
tensor1 (torch.Tensor): First tensor to compare.
tensor2 (torch.Tensor): Second tensor to compare.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
max_print (int): Maximum number of mismatched elements to print.
Raises:
AssertionError: If the tensors are not all close within the given tolerance.
"""
# Check if the shapes of the tensors match
if tensor1.shape != tensor2.shape:
raise AssertionError("Input tensors must have the same shape.")
# Calculate the difference between the tensors
diff = torch.abs(tensor1 - tensor2)
# Determine the tolerance
tolerance = atol + rtol * torch.abs(tensor2)
# Find tolerance mismatched elements
tol_mismatched = diff > tolerance
# Find nan mismatched elements
nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
# Find +inf mismatched elements
posinf_mismatched = torch.logical_xor(
torch.isposinf(tensor1), torch.isposinf(tensor2)
)
# Find -inf mismatched elements
neginf_mismatched = torch.logical_xor(
torch.isneginf(tensor1), torch.isneginf(tensor2)
)
# Find all mismatched elements
mismatched = torch.logical_or(
torch.logical_or(tol_mismatched, nan_mismatched),
torch.logical_or(posinf_mismatched, neginf_mismatched),
)
mismatched_indices = torch.nonzero(mismatched)
# Count the number of mismatched elements
num_mismatched = mismatched.sum().item()
# Check if all elements are close
all_close = num_mismatched == 0
# Raise AssertionError with detailed information if there are mismatches
if not all_close and num_mismatched >= 1:
mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
print_count = min(max_print, num_mismatched)
for index in mismatched_indices[:print_count]:
i = tuple(index.tolist())
mismatch_details.append(
f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
)
if num_mismatched > max_print:
mismatch_details.append(
f"... and {num_mismatched - max_print} more mismatched elements."
)
> raise AssertionError("\n".join(mismatch_details))
E AssertionError: Number of mismatched elements: 1
E Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767
test/utils.py:106: AssertionError
_________________________________ test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] _________________________________
B = 1, T = 4096, V = 128256, beta = 0.9, dtype = torch.float32, atol = 1e-08, rtol = 1e-06
@pytest.mark.parametrize(*_SHAPE_PARAMS)
@pytest.mark.parametrize(*_DTYPE_PARAMS)
@pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
liger_jsd = LigerJSD(beta=beta)
> _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)
test/transformers/test_jsd.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5
def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
"""
Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
Parameters:
tensor1 (torch.Tensor): First tensor to compare.
tensor2 (torch.Tensor): Second tensor to compare.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
max_print (int): Maximum number of mismatched elements to print.
Raises:
AssertionError: If the tensors are not all close within the given tolerance.
"""
# Check if the shapes of the tensors match
if tensor1.shape != tensor2.shape:
raise AssertionError("Input tensors must have the same shape.")
# Calculate the difference between the tensors
diff = torch.abs(tensor1 - tensor2)
# Determine the tolerance
tolerance = atol + rtol * torch.abs(tensor2)
# Find tolerance mismatched elements
tol_mismatched = diff > tolerance
# Find nan mismatched elements
nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
# Find +inf mismatched elements
posinf_mismatched = torch.logical_xor(
torch.isposinf(tensor1), torch.isposinf(tensor2)
)
# Find -inf mismatched elements
neginf_mismatched = torch.logical_xor(
torch.isneginf(tensor1), torch.isneginf(tensor2)
)
# Find all mismatched elements
mismatched = torch.logical_or(
torch.logical_or(tol_mismatched, nan_mismatched),
torch.logical_or(posinf_mismatched, neginf_mismatched),
)
mismatched_indices = torch.nonzero(mismatched)
# Count the number of mismatched elements
num_mismatched = mismatched.sum().item()
# Check if all elements are close
all_close = num_mismatched == 0
# Raise AssertionError with detailed information if there are mismatches
if not all_close and num_mismatched >= 1:
mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
print_count = min(max_print, num_mismatched)
for index in mismatched_indices[:print_count]:
i = tuple(index.tolist())
mismatch_details.append(
f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
)
if num_mismatched > max_print:
mismatch_details.append(
f"... and {num_mismatched - max_print} more mismatched elements."
)
> raise AssertionError("\n".join(mismatch_details))
E AssertionError: Number of mismatched elements: 1
E Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344
test/utils.py:106: AssertionError
___________________________________ test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] ___________________________________
B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06
@pytest.mark.parametrize(*_SHAPE_PARAMS)
@pytest.mark.parametrize("log_target", [True, False])
@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
@pytest.mark.parametrize(*_DTYPE_PARAMS)
def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol):
liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
> _test_correctness_once(
liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target
)
test/transformers/test_kl_div.py:97:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
log_target = False, is_last_layer = True, device = 'cuda'
def _test_correctness_once(
target_kldiv,
B,
T,
V,
dtype,
atol,
rtol,
reduction,
log_target,
is_last_layer=True,
device="cuda",
):
torch.manual_seed(0)
torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)
input = torch.randn(
B * T, V, device=device, dtype=dtype, requires_grad=True
).log_softmax(dim=-1)
x1 = input.detach().clone().requires_grad_(True)
x2 = input.detach().clone().requires_grad_(True)
with torch.no_grad():
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
output = torch_kldiv(x1, target)
output2 = target_kldiv(x2, target)
> assert torch.allclose(output, output2, atol=atol, rtol=rtol)
E AssertionError: assert False
E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
test/transformers/test_kl_div.py:75: AssertionError
______________________________ test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] _______________________________
B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06
@pytest.mark.parametrize(*_SHAPE_PARAMS)
@pytest.mark.parametrize("log_target", [True, False])
@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
@pytest.mark.parametrize(*_DTYPE_PARAMS)
def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol):
liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
> _test_correctness_once(
liger_kldiv,
B,
T,
V,
dtype,
atol,
rtol,
reduction,
log_target,
is_last_layer=False,
)
test/transformers/test_kl_div.py:108:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
log_target = False, is_last_layer = False, device = 'cuda'
def _test_correctness_once(
target_kldiv,
B,
T,
V,
dtype,
atol,
rtol,
reduction,
log_target,
is_last_layer=True,
device="cuda",
):
torch.manual_seed(0)
torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)
input = torch.randn(
B * T, V, device=device, dtype=dtype, requires_grad=True
).log_softmax(dim=-1)
x1 = input.detach().clone().requires_grad_(True)
x2 = input.detach().clone().requires_grad_(True)
with torch.no_grad():
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
output = torch_kldiv(x1, target)
output2 = target_kldiv(x2, target)
> assert torch.allclose(output, output2, atol=atol, rtol=rtol)
E AssertionError: assert False
E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
test/transformers/test_kl_div.py:75: AssertionError
_________________________________________________ test_import_custom_cache_manager _________________________________________________
def test_import_custom_cache_manager():
from triton.runtime.cache import get_cache_manager
from liger_kernel.triton import apply_liger_triton_cache_manager
apply_liger_triton_cache_manager()
> cache_manager = get_cache_manager(key="test_hash")
test/triton/test_triton_monkey_patch.py:17:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:277: in get_cache_manager
return __cache_cls(_base64(key))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
key = 'test_hash'
def _base64(key):
# Assume key is a hex string.
> return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
E ValueError: non-hexadecimal number found in fromhex() arg at position 0
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:261: ValueError
===================================================== short test summary info ======================================================
FAILED test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] - AssertionError: assert False
+ where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
+ where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
+ and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad
+ and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad
FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767
FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344
FAILED test/transformers/test_kl_div.py::test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
+ where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
+ where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
FAILED test/transformers/test_kl_div.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
+ where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
+ where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
FAILED test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager - ValueError: non-hexadecimal number found in fromhex() arg at position 0
================================ 6 failed, 1012 passed, 8 skipped, 72 warnings in 630.02s (0:10:30) ================================
make: *** [Makefile:8: test] Error 1
```
</details>
---------
Co-authored-by: tjtanaa <[email protected]>
Co-authored-by: root <tjtanaa>
commit a2f301759e051278c1491a1acd2e8ae9d09d21c5
Author: hoshi-hiyouga <[email protected]>
Date: Sat Nov 2 14:51:31 2024 +0800
Fix llama forward patch (#339)
## Summary
The present version of Liger Kernel uses `kwargs` in the model forward
function, while in transformers 4.46.0-4.46.1, the `num_items_in_batch`
parameter is passed only when `loss_kwargs` appears in the model's forward
signature [1][2]. Thus, we change `kwargs` to `loss_kwargs` to align with
the transformers implementation [3].
[1]
https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L593
[2]
https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L3620-L3625
[3]
https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/llama/modeling_llama.py#L1137-L1151
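For context, a simplified sketch of the check transformers performs
(per [1]): the parameter must literally be named `loss_kwargs` for the
trainer to pass `num_items_in_batch`:
```python
import inspect

def model_accepts_loss_kwargs(forward) -> bool:
    # A forward defined with plain `**kwargs` fails this name-based
    # check, so the grad-acc fix never kicks in for the patched model.
    return "loss_kwargs" in inspect.signature(forward).parameters
```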
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit 1b04de6b47845f47473500ea18ed55b87e68a68e
Author: Byron Hsu <[email protected]>
Date: Fri Nov 1 13:18:31 2024 -0700
Update pyproject.toml
After https://github.com/linkedin/Liger-Kernel/pull/274, triton needs to be >=2.3.1
commit ac2e8f4563289f7bee0ad9652926afec5c46747b
Author: Yun Dai <[email protected]>
Date: Thu Oct 31 21:46:53 2024 -0700
Fix FusedLinearJSD precision issue when using AMP (#336)
## Summary
1. Make sure all the computation from the logits to the final JSD loss
happens in FP32.
2. Make sure FLJSD works properly under the mixed precision scenario
(a sketch follows after this list), and add a test to guard it.
3. The Torch CE loss impl we use in testing FLCE misses the fp32 cast for
logits; add it back. **NOTE: we should definitely just switch directly to
the [HF impl](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32)
for testing to ensure an always apples-to-apples comparison. See the
added TODO item.**
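A minimal eager-mode sketch of point 1 (hypothetical shapes; the real
kernel fuses these steps): keep everything from the logits to the JSD loss
in fp32 even when the matmul ran under autocast:
```python
import math
import torch
import torch.nn.functional as F

head = torch.nn.Linear(512, 1024, device="cuda")
h_student = torch.randn(8, 512, device="cuda")
h_teacher = torch.randn(8, 512, device="cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    s_logits, t_logits = head(h_student), head(h_teacher)  # bf16 here

# Upcast before any loss math so log-softmax and KL run in fp32.
s = F.log_softmax(s_logits.float(), dim=-1)
t = F.log_softmax(t_logits.float(), dim=-1)
m = torch.logaddexp(s, t) - math.log(2.0)  # log of the mixture (P + Q) / 2
jsd = 0.5 * (F.kl_div(m, s, log_target=True, reduction="batchmean")
             + F.kl_div(m, t, log_target=True, reduction="batchmean"))
```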
## Testing Done
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
commit 659d7d7856bf755c1cf26f2df6173da68841ba17
Author: Chiwan Park <[email protected]>
Date: Fri Nov 1 08:24:06 2024 +0900
Fix incorrect training of first and last Medusa heads (#325)
## Summary
Currently, there are two errors in the Medusa training examples:
1. When we use Liger Kernel, the first head (`model.medusa_head[0]`) is
not trained.
2. When we don't use Liger Kernel, the logits of the last head
(`medusa_logits[-1]`) are ignored.
This PR fixes these errors (see the sketch below).
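For illustration only (dummy tensors, not the example's actual code), the
fix amounts to making sure every head's logits enter the loss:
```python
import torch
import torch.nn.functional as F

num_heads, batch, vocab = 4, 8, 100
medusa_logits = [torch.randn(batch, vocab) for _ in range(num_heads)]
labels = [torch.randint(0, vocab, (batch,)) for _ in range(num_heads)]

# Off-by-one bugs like `medusa_logits[1:]` (skips head 0) or
# `medusa_logits[:-1]` (drops the last head) silently under-train heads.
loss = sum(F.cross_entropy(lg, lb) for lg, lb in zip(medusa_logits, labels))
```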
## Testing Done
- Hardware Type: A100 80GB 8 GPUs
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
commit 827b51c45762d6fc0ffaa7655126467c16f06d44
Author: Byron Hsu <[email protected]>
Date: Thu Oct 31 15:33:05 2024 -0700
Update llama.py
commit e28521bed9f13daacdc363b6975158a2e67ec3a4
Author: Byron Hsu <[email protected]>
Date: Thu Oct 31 14:40:41 2024 -0700
Fix huggingface GA issue for llama (#333)
## Summary
To fix https://github.com/linkedin/Liger-Kernel/pull/322
This PR introduces a new `lce_forward` compatible with
`transformers>=4.46.0` (after the grad acc fix) while ensuring backward
compatibility.
To be specific, I keep the original FLCE untouched and write a new one for
`>=4.46.0`. If the HF version is `<4.46.0`, it shows a deprecation warning
and falls back to the old FLCE.
```python
if transformer_version >= version.parse("4.46.0"):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
else: # if version < 4.46.0
logger.warning(
"Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
"Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
)
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
```
For more context of grad acc fix, please see
https://github.com/huggingface/transformers/pull/34191
## TODO
- [ ] broadcast the changes to all models once the effect is verified.
## Testing Done
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
commit 337bf9a8361740c1caf38ba28b9dc9f7303c9aca
Author: Anish <[email protected]>
Date: Thu Oct 31 06:04:25 2024 +0545
docs(CONTRIBUTING): fix typo (#331)
## Fix typo in CONTRIBUTING.md
This PR corrects a typo in the CONTRIBUTING.md file, changing
"functionaility" to "functionality" in the semantic versioning section.
Co-authored-by: Yun Dai <[email protected]>
commit 48aa62d3ecb0a46009d2b92510a63e39e860fe82
Author: Tcc0403 <[email protected]>
Date: Thu Oct 31 01:15:12 2024 +0800
Add missing ignore_index tests (#310)
## Summary
`ignore_index` in fused_linear_cross_entropy was not tested
## Testing Done
- Hardware Type: gpu-ci
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Byron Hsu <[email protected]>
Co-authored-by: Yun Dai <[email protected]>
commit 1c0c75c3455e788d575966bfc5edec3ef166835e
Author: Yun Dai <[email protected]>
Date: Tue Oct 29 21:59:37 2024 -0700
fix fused JSD with ignore index (#330)
## Summary
1. There's currently a bug in fused linear JSD where we don't extract the
correct subset of labels corresponding to the currently processed chunk
(see the sketch below).
2. Add some tests to make sure results are correct when all tokens are
ignored.
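As a sketch of the pitfall (hypothetical names; the real code runs inside
the chunked fused kernel), each chunk must see the matching slice of the
labels:
```python
import torch

hidden = torch.randn(4096, 512)
labels = torch.randint(0, 32000, (4096,))
chunk_size = 1024

for i, h_chunk in enumerate(hidden.split(chunk_size)):
    # Bug pattern: reusing labels[:chunk_size] (or the full tensor) for
    # every chunk. Fix: slice labels in lockstep with the hidden states.
    labels_chunk = labels[i * chunk_size : i * chunk_size + h_chunk.size(0)]
```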
## Testing Done
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
commit 6cdc93deee15ab6c843149d6ed660c297c5c2d4a
Author: Yun Dai <[email protected]>
Date: Fri Oct 25 17:23:23 2024 -0700
fix FLCE AMP issue (#318)
## Summary
fixes #305: just rely on torch AMP to determine the input dtype when the
AMP context is enabled (see the sketch below)
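A minimal sketch of the idea (the helper name is hypothetical): query the
autocast state instead of trusting the incoming tensor's dtype:
```python
import torch

def amp_compute_dtype(t: torch.Tensor) -> torch.dtype:
    # Under AMP the cast is deferred, so t.dtype can be misleading;
    # ask torch what autocast would produce when it is active.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return t.dtype
```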
## Testing Done
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
commit 9ad8f89373b2206e86e9bb1cdc6e63c37275bd81
Author: Byron Hsu <[email protected]>
Date: Fri Oct 25 09:53:42 2024 -0700
Update README.md
commit 4e2f7c6b9185560294c24ee48c32c07cefc7e828
Author: Byron Hsu <[email protected]>
Date: Fri Oct 25 09:53:08 2024 -0700
remove torch compile section until the issue is fixed
commit 99599091373f178e8ad6a69ecb1b32351d1d5c1f
Author: Byron Hsu <[email protected]>
Date: Mon Oct 21 14:41:32 2024 -0700
Update README.md
commit e49b83a4af985ef1f75c994bbdb7ed103b22ae11
Author: Byron Hsu <[email protected]>
Date: Mon Oct 21 14:40:01 2024 -0700
Update citation and add tech report (#317)
## Summary
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
commit 7da01b7188266342b94858fd2e01bf037099441c
Author: Kürşat Aktaş <[email protected]>
Date: Tue Oct 22 00:22:41 2024 +0300
Introducing Liger Kernel Guru on Gurubase.io (#316)
I created the [Liger Kernel Guru](https://gurubase.io/g/liger-kernel)
badge on Gurubase.io upon request from @ByronHsu.
Adding a new badge next to the Discord badge made all the badge text
smaller, as the current style pre…
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)

## Installation

### Dependencies

#### CUDA

- `torch >= 2.1.2`
- `triton >= 2.3.0`

#### ROCm

- `torch >= 2.5.0` Install according to the instructions on the PyTorch official webpage.
- `triton >= 3.0.0` Install from PyPI (e.g. `pip install triton==3.0.0`).

### Optional Dependencies

- `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers.

> **Note:**
> Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single Triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- **Matmul int2xint8**: is implemented by using cache-tiled matrix multiplication and by fusing the matmul with the unpacking process, which achieves a considerable speedup and performs on par with @torch.compile
<!-- TODO: be more specific about batch size -->
> **Note:**
> Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

## Note on ML Compiler

### Torch Compile

Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.
This project is licensed under the [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE) License (see `LICENSE` for details).
It also includes components from projects licensed under:

- Apache License 2.0 (see `LICENSE-APACHE-2.0` for details).
- MIT License (see `LICENSE-MIT-AutoAWQ` for details).
- MIT License (see `LICENSE-MIT-Efficient Cross Entropy` for details).
- MIT License (see `LICENSE-MIT-llmc` for details).
- MIT License (see `LICENSE-MIT-triton` for details).

## Contact
Biblatex entry:
```bib
@article{liger2024,
  title={Liger Kernel: Efficient Triton Kernels for LLM Training},
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
  year={2024},
  eprint={2410.10989},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.10989},
  journal={arXiv preprint arXiv:2410.10989}
}
```

## Star History
[Star History Chart](https://star-history.com/#linkedin/Liger-Kernel&Date)