Skip to content

Commit 5156018

Browse files
shivam15slancerts
andauthored
add out of bounds check to cross entropy (#588)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Same as title <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <[email protected]>
1 parent 0c2203e commit 5156018

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/liger_kernel/ops/cross_entropy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ def cross_entropy_forward(
285285

286286
target_mask = target != ignore_index
287287
n_non_ignore = target_mask.sum().item()
288+
assert (target * target_mask).max() < _input.shape[-1], (
289+
f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
290+
)
291+
assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
288292
sum_non_ignore_weight = n_non_ignore
289293
weight_sum = 0.0
290294
if weight is not None:

test/transformers/test_cross_entropy.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,32 @@ def _test_correctness_with_z_loss_with_other_params_once(
290290
assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
291291

292292

293+
def _test_correctness_with_out_of_bounds_target_once(target_ce, B, T, V, ignore_index):
294+
torch.manual_seed(0)
295+
296+
_tensor = torch.randn(B * T, V, device=device, dtype=torch.bfloat16)
297+
_input = _tensor.detach().clone().requires_grad_(True)
298+
target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
299+
300+
# Assign some random number of elements as ignore_index
301+
num_elements_to_assign = torch.randint(
302+
1, B * T // 2, (1,)
303+
).item() # Random number of elements to set to ignore_index
304+
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices
305+
target[indices_to_assign] = ignore_index
306+
307+
# Assign out of bounds target
308+
num_out_of_bounds = torch.randint(1, B * T // 2, (1,)).item()
309+
indices_to_assign = torch.randperm(B * T)[:num_out_of_bounds] # Randomly select indices
310+
target[indices_to_assign] = torch.randint(V, 2 * V, (num_out_of_bounds,)).to(device)
311+
312+
try:
313+
_ = target_ce(_input, target)
314+
assert False, "Should have thrown an error"
315+
except AssertionError as e:
316+
assert "out of bounds" in str(e)
317+
318+
293319
def _test_correctness_with_weight_once(target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol):
294320
torch.manual_seed(0)
295321
torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction)
@@ -916,3 +942,16 @@ def test_float32_internal():
916942

917943
torch.allclose(X_bf16, X_fp32.bfloat16())
918944
torch.allclose(loss_bf16, loss_fp32)
945+
946+
947+
@pytest.mark.parametrize(
948+
"B, T, V, ignore_index",
949+
[
950+
(2, 4096, 32000, 2),
951+
# weird shapes
952+
(3, 423, 32000, -123),
953+
],
954+
)
955+
def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
956+
liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
957+
_test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)

0 commit comments

Comments
 (0)