Skip to content

Commit 8e8e62b

Browse files
✂️ [DPO] Fix truncation keep_end leading to zero'd out samples (#3398)
Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 824100c commit 8e8e62b

File tree

3 files changed

+120
-36
lines changed

3 files changed

+120
-36
lines changed

tests/test_utils.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
batch_generation,
3333
decode_and_strip_padding,
3434
flush_left,
35+
flush_right,
3536
generate_model_card,
3637
get_peft_config,
3738
pad,
@@ -474,22 +475,64 @@ def test_single_row(self):
474475
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
475476

476477
def test_no_shift_needed(self):
477-
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0]])
478-
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 8, 0, 0]])
478+
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])
479+
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
479480
new_mask, new_tensor1 = flush_left(mask, tensor1)
480481

481-
expected_mask = torch.tensor([[1, 1], [1, 1]])
482-
expected_tensor1 = torch.tensor([[5, 6], [7, 8]])
482+
expected_mask = torch.tensor([[1, 1], [1, 0]])
483+
expected_tensor1 = torch.tensor([[5, 6], [7, 0]])
483484

484485
self.assertTrue(torch.equal(new_mask, expected_mask))
485486
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
486487

487488
def test_no_tensors(self):
488489
mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
489490
new_mask = flush_left(mask)
490-
491491
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
492+
self.assertTrue(torch.equal(new_mask, expected_mask))
492493

494+
495+
class TestFlushRight(unittest.TestCase):
    """Unit tests for `flush_right`: rows are right-aligned and all-zero leading columns are dropped."""

    def _assert_flushed(self, inputs, expected):
        # `flush_right` returns a bare tensor when called with the mask only, and a tuple otherwise.
        outputs = flush_right(*inputs)
        if len(inputs) == 1:
            outputs = (outputs,)
        self.assertEqual(len(outputs), len(expected))
        for actual, wanted in zip(outputs, expected):
            self.assertTrue(torch.equal(actual, wanted))

    def test_basic_case(self):
        # Rows with different amounts of right padding, plus two companion tensors.
        inputs = (
            torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]),
            torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]]),
            torch.tensor([[7, 8, 9, 0, 0], [0, 0, 10, 11, 0]]),
        )
        expected = (
            torch.tensor([[1, 1, 1], [0, 1, 1]]),
            torch.tensor([[2, 3, 4], [0, 5, 6]]),
            torch.tensor([[7, 8, 9], [0, 10, 11]]),
        )
        self._assert_flushed(inputs, expected)

    def test_single_row(self):
        inputs = (
            torch.tensor([[1, 1, 0, 0]]),
            torch.tensor([[2, 3, 0, 0]]),
        )
        expected = (
            torch.tensor([[1, 1]]),
            torch.tensor([[2, 3]]),
        )
        self._assert_flushed(inputs, expected)

    def test_no_shift_needed(self):
        # Rows are already right-aligned; only the empty leading columns are removed.
        inputs = (
            torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]]),
            torch.tensor([[0, 0, 5, 6], [0, 0, 0, 7]]),
        )
        expected = (
            torch.tensor([[1, 1], [0, 1]]),
            torch.tensor([[5, 6], [0, 7]]),
        )
        self._assert_flushed(inputs, expected)

    def test_no_tensors(self):
        # Mask-only call: the result is a single tensor rather than a tuple.
        inputs = (torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]),)
        expected = (torch.tensor([[1, 1, 1], [0, 1, 1]]),)
        self._assert_flushed(inputs, expected)
494537

495538

trl/trainer/dpo_trainer.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
disable_dropout_in_model,
6363
empty_cache,
6464
flush_left,
65+
flush_right,
6566
generate_model_card,
6667
get_comet_experiment_url,
6768
log_table_to_comet_experiment,
@@ -1124,26 +1125,35 @@ def concatenated_forward(
11241125
dim=1,
11251126
)
11261127

1127-
# Flush left to reduce the memory usage
1128-
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
1129-
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
1130-
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
1131-
1132-
# Truncate right
1133-
if self.max_length is not None:
1134-
if self.truncation_mode == "keep_end":
1128+
# Flush and truncate
1129+
if self.max_length is not None and self.max_length < attention_mask.size(1):
1130+
if self.truncation_mode == "keep_start":
1131+
# Flush left to reduce the memory usage
1132+
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
1133+
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
1134+
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
1135+
attention_mask = attention_mask[:, : self.max_length]
1136+
input_ids = input_ids[:, : self.max_length]
1137+
loss_mask = loss_mask[:, : self.max_length]
1138+
elif self.truncation_mode == "keep_end":
1139+
# Flush right before truncating left, then flush left
1140+
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
1141+
# [0, x, x, x, 0, 0]] [0, x, x, x]]
1142+
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
11351143
input_ids = input_ids[:, -self.max_length :]
11361144
attention_mask = attention_mask[:, -self.max_length :]
11371145
loss_mask = loss_mask[:, -self.max_length :]
1138-
elif self.truncation_mode == "keep_start":
1139-
input_ids = input_ids[:, : self.max_length]
1140-
attention_mask = attention_mask[:, : self.max_length]
1141-
loss_mask = loss_mask[:, : self.max_length]
1146+
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
11421147
else:
11431148
raise ValueError(
11441149
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
11451150
"'keep_start']."
11461151
)
1152+
else:
1153+
# Flush left to reduce the memory usage
1154+
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
1155+
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
1156+
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
11471157

11481158
if self.use_logits_to_keep:
11491159
# Compute logits_to_keep based on loss_mask pattern:

trl/trainer/utils.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,7 +1618,7 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
16181618
experiment.log_table(tabular_data=table, filename=name)
16191619

16201620

1621-
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
1621+
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
16221622
"""
16231623
Shift non-zero elements in the mask and corresponding tensors to the left.
16241624
@@ -1660,28 +1660,59 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
16601660
[5, 6, 0]])
16611661
```
16621662
"""
1663+
_, M = mask.shape
1664+
16631665
# Create copy of mask and tensors
1664-
mask = mask.clone()
1666+
mask_copy = mask.clone()
16651667
tensors = [t.clone() for t in tensors]
16661668

16671669
# Shift non-zero values to the left
1668-
for i in range(mask.size(0)):
1669-
first_one_idx = torch.nonzero(mask[i])[0].item()
1670-
mask[i] = torch.roll(mask[i], shifts=-first_one_idx)
1671-
for tensor in tensors:
1672-
tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx)
1673-
1674-
# Get the first column idx that is all zeros and remove every column after that
1675-
empty_cols = torch.sum(mask, dim=0) == 0
1676-
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
1677-
mask = mask[:, :first_empty_col]
1678-
for i, tensor in enumerate(tensors):
1679-
tensors[i] = tensor[:, :first_empty_col]
1680-
1681-
if not tensors:
1682-
return mask
1683-
else:
1684-
return mask, *tensors
1670+
first_non_zero = mask_copy.argmax(dim=1)
1671+
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
1672+
idx_roll = (pos + first_non_zero.unsqueeze(1)) % M
1673+
mask_roll = mask_copy.gather(1, idx_roll)
1674+
rolled_tensors = [t.gather(1, idx_roll) for t in tensors]
1675+
1676+
# Truncate trailing columns that are all zeros in mask_roll
1677+
col_sums = mask_roll.sum(dim=0)
1678+
empty_cols = col_sums == 0
1679+
first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M
1680+
flushed_mask = mask_roll[:, :first_empty_col]
1681+
flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors]
1682+
1683+
if not flushed_tensors:
1684+
return flushed_mask
1685+
return flushed_mask, *flushed_tensors
1686+
1687+
1688+
def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    """
    Shift non-zero elements in the mask and corresponding tensors to the right.

    Each row is rolled so that its last non-zero mask entry lands in the last column, then the leading columns that
    are zero in every row of the mask are removed. The same roll and the same column truncation are applied to every
    tensor in `tensors`. Any non-zero mask value counts as "active", so boolean and non-binary masks are supported.
    A row whose mask is entirely zero is left unshifted.

    Args:
        mask (`torch.Tensor`):
            2D tensor (e.g. an attention mask) with shape `(N, M)`.
        *tensors (`torch.Tensor`):
            One or more 2D tensors with the same shape as `mask`, rolled and truncated together with it.

    Returns:
        `torch.Tensor`:
            The flushed mask, when no extra tensors are given.
        `tuple[torch.Tensor, ...]`:
            The flushed mask followed by the flushed tensors, when extra tensors are given.

    Example:
    ```python
    >>> mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
    >>> tensor = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]])
    >>> new_mask, new_tensor = flush_right(mask, tensor)
    >>> print(new_mask)
    tensor([[1, 1, 1],
            [0, 1, 1]])
    >>> print(new_tensor)
    tensor([[2, 3, 4],
            [0, 5, 6]])
    ```
    """
    _, M = mask.shape

    # A zero-width mask has nothing to shift (and `argmax` rejects an empty reduction dimension).
    if M == 0:
        if not tensors:
            return mask.clone()
        return mask.clone(), *[t.clone() for t in tensors]

    # Work on "is non-zero" rather than on the raw values: `argmax` over the raw mask would pick the position of
    # the largest value instead of the first non-zero one, and is not available for boolean tensors.
    active = mask != 0

    # Number of trailing zeros per row == index of the first non-zero entry in the reversed row.
    # `argmax` returns the first maximal index; an all-zero row yields 0, i.e. no shift.
    trailing_zeros = torch.fliplr(active).to(torch.int8).argmax(dim=1)

    # Roll every row to the right by its trailing-zero count. `gather` allocates fresh outputs, so the inputs are
    # never modified and no defensive copies are needed.
    pos = torch.arange(M, device=mask.device).unsqueeze(0)
    idx_roll = (pos - trailing_zeros.unsqueeze(1)) % M
    mask_roll = mask.gather(1, idx_roll)
    rolled_tensors = [t.gather(1, idx_roll) for t in tensors]

    # Truncate leading columns that are all zeros in mask_roll. Using `any` instead of a column sum avoids
    # misclassifying a column whose signed values happen to cancel out.
    non_empty_cols = (mask_roll != 0).any(dim=0)
    first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M
    flushed_mask = mask_roll[:, first_non_empty_col:]
    flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors]

    if not flushed_tensors:
        return flushed_mask
    return flushed_mask, *flushed_tensors
16851716

16861717

16871718
def selective_log_softmax(logits, index):

0 commit comments

Comments
 (0)