Closed
Labels: good first issue (Good for newcomers)
Description
🐛 Describe the bug
Liger-Kernel/test/transformers/test_cross_entropy.py
Lines 188 to 201 in 7e0f459
_input = _tensor.to(torch.float32).detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)
target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
# downcasting to original dtype
output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype)
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)
output.backward()
output2.backward()
There should be a torch.allclose() check on the gradients after backward().
The .to(torch.float32) in L188 can be moved to L194, so that _input keeps the same dtype as _input2.
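A minimal sketch of both suggestions, assuming a CPU-only setup with PyTorch installed. The upcast moves into the reference forward pass, so _input and _input2 are both leaves in the original dtype, and a second torch.allclose() compares their gradients after backward(). Keeping both leaves in one dtype matters because torch.allclose() expects its two tensors to share a dtype; with the cast at L188, _input.grad would be float32 while _input2.grad stays in dtype. The names reference_loss, candidate_loss and check_softcap_ce are hypothetical, and candidate_loss is only a stand-in for target_ce (the Liger kernel), which is not reproduced here.

```python
import torch
import torch.nn.functional as F


def reference_loss(x, target, softcap, dtype):
    # Upcast inside the forward pass (the cast moved from L188 to L194),
    # apply tanh softcapping, then downcast the loss to the original dtype.
    logits = softcap * torch.tanh(x.to(torch.float32) / softcap)
    return F.cross_entropy(logits, target).to(dtype)


def candidate_loss(x, target, softcap, dtype):
    # Hypothetical stand-in for target_ce: same math, written as
    # log_softmax + gather instead of F.cross_entropy.
    logits = softcap * torch.tanh(x.float() / softcap)
    logp = logits.log_softmax(dim=-1)
    return -logp.gather(1, target.unsqueeze(1)).mean().to(dtype)


def check_softcap_ce(B=2, T=4, V=16, softcap=30.0, dtype=torch.bfloat16,
                     atol=1e-2, rtol=1e-2):
    torch.manual_seed(0)
    _tensor = torch.randn(B * T, V).to(dtype)
    # Both leaves stay in `dtype`, so their .grad tensors share a dtype too.
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)
    target = torch.randint(0, V, (B * T,), dtype=torch.long)

    output = reference_loss(_input, target, softcap, dtype)
    output2 = candidate_loss(_input2, target, softcap, dtype)
    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    # The missing check: compare gradients after backward().
    assert _input.grad.dtype == _input2.grad.dtype == dtype
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
    return _input.grad
```

Calling check_softcap_ce() runs both comparisons and returns the reference gradient, which autograd has cast back to bfloat16 on its way into _input.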
Reproduce
No response
Versions
None