
Commit 439fe1c

Enhance Cross Entropy Softcap Unit Test (#423)
## Summary

Closes #418

- Add gradient check after `backward()`.
- Defer type conversion and only upcast before the `tanh` operation. This keeps the original tensor `dtype` during cloning.

## Testing Done

```
============================= test session starts ==============================
platform linux -- Python 3.12.1, pytest-8.3.3, pluggy-1.5.0
rootdir: /root/liger-kernel
configfile: pyproject.toml
plugins: anyio-4.2.0, typeguard-4.1.5
collected 77 items

test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-sum-2-4096-32000] PASSED [ 1%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-sum-3-423-32000] PASSED [ 2%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-mean-2-4096-32000] PASSED [ 3%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype0-1e-08-0.05-mean-3-423-32000] PASSED [ 5%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000] PASSED [ 6%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED [ 7%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000] PASSED [ 9%]
test/transformers/test_cross_entropy.py::test_correctness[1.0-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED [ 10%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype0-1e-08-0.05-2-2-8] PASSED [ 11%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype0-1e-08-0.05-9-7-41] PASSED [ 12%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype1-1e-08-1e-06-2-2-8] PASSED [ 14%]
test/transformers/test_cross_entropy.py::test_correctness_functional[1.0-dtype1-1e-08-1e-06-9-7-41] PASSED [ 15%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-sum-2-4096-32000-2] PASSED [ 16%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-sum-3-423-32000--123] PASSED [ 18%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-mean-2-4096-32000-2] PASSED [ 19%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype0-1e-08-0.05-mean-3-423-32000--123] PASSED [ 20%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000-2] PASSED [ 22%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-sum-3-423-32000--123] PASSED [ 23%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000-2] PASSED [ 24%]
test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[1.0-dtype1-1e-08-1e-06-mean-3-423-32000--123] PASSED [ 25%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype0-1e-08-0.05-2-4096-32000-0.1] PASSED [ 27%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype0-1e-08-0.05-3-423-32000-0.1] PASSED [ 28%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype1-1e-08-1e-06-2-4096-32000-0.1] PASSED [ 29%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_once[1.0-dtype1-1e-08-1e-06-3-423-32000-0.1] PASSED [ 31%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype0-1e-08-0.05-2-4096-32000-1-0.1] PASSED [ 32%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype0-1e-08-0.05-3-423-32000--300-0.2] PASSED [ 33%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype1-1e-08-1e-06-2-4096-32000-1-0.1] PASSED [ 35%]
test/transformers/test_cross_entropy.py::test_correctness_with_label_smoothing_with_ignore_index_once[1.0-dtype1-1e-08-1e-06-3-423-32000--300-0.2] PASSED [ 36%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-sum-2-4096-32000-30.0] PASSED [ 37%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-sum-3-423-32000-30.0] PASSED [ 38%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-mean-2-4096-32000-30.0] PASSED [ 40%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype0-1e-08-0.05-mean-3-423-32000-30.0] PASSED [ 41%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000-30.0] PASSED [ 42%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-sum-3-423-32000-30.0] PASSED [ 44%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000-30.0] PASSED [ 45%]
test/transformers/test_cross_entropy.py::test_correctness_with_softcap_once[1.0-dtype1-1e-08-1e-06-mean-3-423-32000-30.0] PASSED [ 46%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 48%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 49%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 50%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-True-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 51%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 53%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 54%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 55%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[0.0001-False-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 57%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 58%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 59%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 61%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-True-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 62%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 63%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 64%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 66%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_once[1e-05-False-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 67%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 68%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 70%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 71%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-True-0.0001-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 72%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 74%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 75%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 76%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.1-42-mean-False-1e-05-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 77%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 79%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 80%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 81%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-True-0.0001-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 83%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype0-1e-08-0.05-2-4096-32000] PASSED [ 84%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype0-1e-08-0.05-3-423-32000] PASSED [ 85%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype1-1e-08-1e-06-2-4096-32000] PASSED [ 87%]
test/transformers/test_cross_entropy.py::test_correctness_with_z_loss_with_other_params_once[0.2--42-sum-False-1e-05-1.0-dtype1-1e-08-1e-06-3-423-32000] PASSED [ 88%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-sum-2-4096-32000] PASSED [ 89%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-sum-3-423-32000] PASSED [ 90%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-mean-2-4096-32000] PASSED [ 92%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype0-1e-08-0.05-mean-3-423-32000] PASSED [ 93%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-sum-2-4096-32000] PASSED [ 94%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED [ 96%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-mean-2-4096-32000] PASSED [ 97%]
test/transformers/test_cross_entropy.py::test_correctness_not_last_layer[1.0-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED [ 98%]
test/transformers/test_cross_entropy.py::test_float32_internal PASSED [100%]

=============================== warnings summary ===============================
../../usr/local/lib/python3.12/site-packages/_pytest/config/__init__.py:1441
  /usr/local/lib/python3.12/site-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_mode
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 77 passed, 1 warning in 29.23s ========================
```

- Hardware Type: A10G
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
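For readers unfamiliar with softcapping, the sketch below illustrates the `softcap * tanh(x / softcap)` transform the test exercises, along with the kind of gradient check this commit adds. It is plain Python rather than the repository's code: the helper names `softcap_logit`, `softcap_grad`, and `numeric_grad` are hypothetical, and the finite-difference comparison stands in for the `torch.allclose` check on `.grad`.

```python
import math


def softcap_logit(x: float, softcap: float) -> float:
    """Squash a logit smoothly into the open interval (-softcap, softcap)."""
    return softcap * math.tanh(x / softcap)


def softcap_grad(x: float, softcap: float) -> float:
    """Analytic derivative of softcap_logit with respect to x."""
    return 1.0 - math.tanh(x / softcap) ** 2


def numeric_grad(f, x: float, eps: float = 1e-5) -> float:
    """Central finite-difference estimate of df/dx (hypothetical helper)."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


if __name__ == "__main__":
    cap = 30.0  # same softcap value that appears in the test IDs above
    for x in (-100.0, -3.0, 0.0, 5.0, 80.0):
        analytic = softcap_grad(x, cap)
        numeric = numeric_grad(lambda v: softcap_logit(v, cap), x)
        print(x, softcap_logit(x, cap), analytic, numeric)
```

Comparing forward outputs alone can miss a backward-pass bug, which is why a gradient comparison is a useful complement to the existing loss comparison.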
1 parent a8d55fb commit 439fe1c

1 file changed, +8 -4 lines changed
test/transformers/test_cross_entropy.py

Lines changed: 8 additions & 4 deletions
@@ -184,21 +184,25 @@ def _test_correctness_with_softcap_once(
     torch_ce = CrossEntropyLoss(reduction=reduction)
 
     _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
-    # upcasting to match liger's casting strategy
-    _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True)
+    _input = _tensor.detach().clone().requires_grad_(True)
     _input2 = _tensor.detach().clone().requires_grad_(True)
 
     target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
 
-    # downcasting to original dtype
-    output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype)
+    # upcasting to match liger's casting strategy
+    # and downcasting to original dtype
+    output = torch_ce(
+        softcap * torch.tanh(_input.to(torch.float32) / softcap), target
+    ).to(dtype)
     output2 = target_ce(_input2, target)
 
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
 
     output.backward()
     output2.backward()
 
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
 
 def _test_correctness_with_z_loss_once(
     target_ce,
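
One effect of deferring the upcast in the diff above is on the dtype of the leaf tensor's gradient. The following is a minimal CPU sketch of that effect, not the repository's test: it assumes PyTorch is installed, uses `torch.nn.functional.cross_entropy` in place of both loss modules, and the names `base`, `early`, and `late` are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
softcap = 30.0
dtype = torch.bfloat16

base = torch.randn(6, 11, dtype=dtype)
target = torch.randint(0, 11, (6,))

# Old approach: upcast before cloning, so the leaf tensor itself is float32.
early = base.to(torch.float32).detach().clone().requires_grad_(True)
# New approach: clone in the original dtype and upcast only inside the graph.
late = base.detach().clone().requires_grad_(True)

loss_early = F.cross_entropy(softcap * torch.tanh(early / softcap), target)
loss_late = F.cross_entropy(
    softcap * torch.tanh(late.to(torch.float32) / softcap), target
)

loss_early.backward()
loss_late.backward()

# The gradient of a leaf takes the leaf's dtype, so only the deferred upcast
# yields a gradient in the original (bfloat16) dtype.
print(early.grad.dtype, late.grad.dtype)
```

With the deferred upcast, the reference gradient is stored in the same dtype as the input tensor, so the comparison against `_input2.grad` in the new assertion is between like-typed tensors.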
