Commit 439fe1c
Enhance Cross Entropy Softcap Unit Test (#423)
## Summary
Closes #418
- Add gradient check after `backward()`.
- Defer type conversion: upcast only right before the `tanh` operation, so that cloned tensors keep their original `dtype`.
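The two changes can be illustrated with a minimal sketch, assuming PyTorch on CPU. `softcap_cross_entropy` is a hypothetical reference helper, not the function in `test_cross_entropy.py`, and the second call only stands in for the kernel under test. The `bfloat16` dtype is an assumption; the `30.0` softcap matches the value in the softcap test IDs.

```python
import torch
import torch.nn.functional as F


def softcap_cross_entropy(logits, target, softcap, reduction="mean"):
    """Hypothetical reference loss: upcast to float32 only right before tanh."""
    capped = softcap * torch.tanh(logits.float() / softcap)
    return F.cross_entropy(capped, target, reduction=reduction)


torch.manual_seed(0)
dtype = torch.bfloat16  # assumed low-precision test dtype
softcap = 30.0          # value taken from the softcap test IDs
logits = (torch.randn(6, 16) * 5.0).to(dtype)
target = torch.randint(0, 16, (6,))

# Clone without converting, so both copies keep the original dtype.
x_ref = logits.detach().clone().requires_grad_(True)
x_new = logits.detach().clone().requires_grad_(True)

loss_ref = softcap_cross_entropy(x_ref, target, softcap)
# Stand-in for the implementation under test (assumption: same math here).
loss_new = softcap_cross_entropy(x_new, target, softcap)

loss_ref.backward()
loss_new.backward()

# Compare the loss and, after backward(), the input gradients.
torch.testing.assert_close(loss_new, loss_ref)
torch.testing.assert_close(x_new.grad, x_ref.grad)
```

Because the upcast happens inside the loss function, the leaf tensors and their gradients stay in the original dtype, while the `tanh` and the loss itself are computed in `float32`.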
## Testing Done
```
============================= test session starts ==============================
platform linux -- Python 3.12.1, pytest-8.3.3, pluggy-1.5.0
rootdir: /root/liger-kernel
configfile: pyproject.toml
plugins: anyio-4.2.0, typeguard-4.1.5
collected 77 items
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-sum-2-4096-32000] PASSED [ 1%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-sum-3-423-32000] PASSED [ 2%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-mean-2-4096-32000] PASSED [ 3%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-mean-3-423-32000] PASSED [ 5%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000] PASSED [ 6%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED [ 7%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000] PASSED [ 9%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED [ 10%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype0-1e-08-0.05-2-2-8] PASSED [ 11%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype0-1e-08-0.05-9-7-41] PASSED [ 12%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype1-1e-08-1e-06-2-2-8] PASSED [ 14%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype1-1e-08-1e-06-9-7-41] PASSED [ 15%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-sum-2-4096-32000-2] PASSED [ 16%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-sum-3-423-32000--123] PASSED [ 18%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-mean-2-4096-32000-2] PASSED [ 19%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-mean-3-423-32000--123] PASSED [ 20%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000-2] PASSED [ 22%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-sum-3-423-32000--123] PASSED [ 23%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000-2] PASSED [ 24%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-mean-3-423-32000--123] PASSED [ 25%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype0-1e-08-0.05-2-4096-32000-0.1] PASSED [ 27%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype0-1e-08-0.05-3-423-32000-0.1] PASSED [ 28%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype1-1e-08-1e-06-2-4096-32000-0.1] PASSED [ 29%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype1-1e-08-1e-06-3-423-32000-0.1] PASSED [ 31%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype0-1e-08-0.05-2-4096-32000-1-0.1] PASSED [ 32%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype0-1e-08-0.05-3-423-32000--300-0.2] PASSED [ 33%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype1-1e-08-1e-06-2-4096-32000-1-0.1] PASSED [ 35%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype1-1e-08-1e-06-3-423-32000--300-0.2] PASSED [ 36%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-sum-2-4096-32000-30.0] PASSED [ 37%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-sum-3-423-32000-30.0] PASSED [ 38%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-mean-2-4096-32000-30.0] PASSED [ 40%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-mean-3-423-32000-30.0] PASSED [ 41%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000-30.0] PASSED [ 42%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-sum-3-423-32000-30.0] PASSED [ 44%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000-30.0] PASSED [ 45%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-mean-3-423-32000-30.0] PASSED [ 46%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 48%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 49%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 50%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 51%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 53%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 54%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 55%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 57%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 58%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 59%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 61%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 62%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 63%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 64%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 66%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 67%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 68%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 70%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 71%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 72%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 74%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 75%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 76%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 77%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 79%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 80%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 81%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 83%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 84%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 85%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 87%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 88%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-sum-2-4096-32000] PASSED [ 89%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-sum-3-423-32000] PASSED [ 90%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-mean-2-4096-32000] PASSED [ 92%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-mean-3-423-32000] PASSED [ 93%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000] PASSED [ 94%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED [ 96%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000] PASSED [ 97%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED [ 98%]
test/transformers/test_cross_entropy.py::test_float32_internal PASSED [100%]
=============================== warnings summary ===============================
../../usr/local/lib/python3.12/site-packages/_pytest/config/__init__.py:1441
/usr/local/lib/python3.12/site-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_mode
self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 77 passed, 1 warning in 29.23s ========================
```
- Hardware Type: A10G
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
---------
Signed-off-by: Austin Liu <[email protected]>
1 parent a8d55fb, commit 439fe1c
1 file changed: `test/transformers/test_cross_entropy.py` (8 additions, 4 deletions; original lines 184-204, new lines 184-208)