
Commit 9bc0836

fix(offline_pipeline): ILQL negative indexing under truncation (#435)
* fix(offline_pipeline): prepend `is_output=False` msg when truncated
* fix(test_pipelines): specify `truncation_side`
* style
* fix(offline_pipeline): prepend starting <bos> under truncation
* fix(test_pipelines): update tests for the truncation change
* docs(offline_pipeline): update `tokenize_dialogue` type signature
1 parent 2318d04 commit 9bc0836
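
For context, a minimal sketch of the behavior this commit establishes, mirroring the new basic test in this diff (assumes the gpt2 tokenizer used by the test suite):

    from transformers import AutoTokenizer
    from trlx.pipeline.offline_pipeline import DialogMessage, tokenize_dialogue

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.truncation_side = "left"

    # Left-truncating to max_length=2 consumes the entire prompt, but the
    # returned dialog still begins with an is_output=False <bos> message, so
    # consumers that index outputs relative to the prompt boundary (where the
    # ILQL negative-indexing bug in the title surfaced) see a well-formed dialog.
    dialog = tokenize_dialogue(["this will be truncated", "."], tokenizer, max_length=2)
    assert dialog[0] == DialogMessage(is_output=False, tokens=(tokenizer.bos_token_id,))
    assert dialog[1].is_output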

File tree

2 files changed: +41 -15 lines changed

tests/test_pipelines.py

Lines changed: 29 additions & 11 deletions
@@ -11,6 +11,19 @@ class TestTokenizeDialog(TestCase):
     def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

+    def test_tokenize_dialogue_truncation_basic(self):
+        dialogue = ["this will be truncated", "."]
+        self.tokenizer.truncation_side = "left"
+
+        dialog = tokenize_dialogue(dialogue, self.tokenizer, max_length=2)
+
+        assert len(dialog) == 2
+        user_dm, bot_dm = dialog
+        assert len(user_dm.tokens) == 1
+        assert len(bot_dm.tokens) == 1
+        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
+        assert bot_dm == DialogMessage(is_output=True, tokens=(self.tokenizer.eos_token_id,))
+
     @given(st.lists(st.text(), max_size=32))
     def test_tokenize_dialogue_single_turn(self, response_words):
         response = " ".join(response_words)  # space seperate to make it multiple tokens
@@ -46,20 +59,18 @@ def test_tokenize_dialogue_single_turn_truncation_left(self, response_words, max
         response = " ".join(response_words)  # space seperate to make it multiple tokens
         self.tokenizer.truncation_side = "left"
         tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
-        tokenized_response = tokenized_response + (self.tokenizer.eos_token_id,)
+        tokenized_response += (self.tokenizer.eos_token_id,)
         dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)

-        # if no truncation should have happened, then the user BOS prompt should be present
-        if len(tokenized_response) + 1 <= max_length:
-            assert len(dialog) == 2
-            user_dm, bot_dm = dialog
+        # whether or not truncation has happened, user BOS prompt should be present
+        assert len(dialog) == 2
+        user_dm, bot_dm = dialog
+        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))

-            assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
+        if len(tokenized_response) < max_length:
             assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)
         else:
-            assert len(dialog) == 1
-            bot_dm = dialog[0]
-            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[-max_length:])
+            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[-max_length + 1 :])

         all_tokens = sum((dm.tokens for dm in dialog), ())
         assert len(all_tokens) <= max_length
@@ -76,6 +87,9 @@ def test_tokenize_dialogue_multi_turn(self, user_response_pairs):

         dm_convo = [DialogMessage(is_output=i % 2 == 1, tokens=tokens) for i, tokens in enumerate(tokenized_flat_convo)]
         nonempty_dm_convo = [dm for dm in dm_convo if dm.tokens]
+        if nonempty_dm_convo[0].is_output:
+            nonempty_dm_convo.insert(0, DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)))
+
         assert dialog == nonempty_dm_convo

     @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32), st.integers(min_value=2, max_value=16))
@@ -91,6 +105,9 @@ def test_tokenize_dialogue_multi_turn_truncation_right(self, user_response_pairs

         all_tokens = sum((dm.tokens for dm in dialog), ())
         should_be_tokens = sum(tokenized_flat_convo, ())[:max_length]
+        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
+            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[: max_length - 1])
+
         assert all_tokens == should_be_tokens
         assert len(all_tokens) <= max_length

@@ -106,8 +123,9 @@ def test_tokenize_dialogue_multi_turn_truncation_left(self, user_response_pairs,
         dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)

         all_tokens = sum((dm.tokens for dm in dialog), ())
-
         should_be_tokens = sum(tokenized_flat_convo, ())[-max_length:]
-        assert all_tokens == should_be_tokens
+        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
+            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[-max_length + 1 :])

+        assert all_tokens == should_be_tokens
         assert len(all_tokens) <= max_length
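
A note on these assertions: the tests compare the prepended message against `eos_token_id` while `offline_pipeline.py` inserts `bos_token_id`; for gpt2 both names resolve to the same `<|endoftext|>` token, so the checks are equivalent. A quick sanity check:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # gpt2 reuses <|endoftext|> (id 50256) for both bos and eos
    assert tokenizer.bos_token_id == tokenizer.eos_token_id == 50256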

trlx/pipeline/offline_pipeline.py

Lines changed: 12 additions & 4 deletions
@@ -26,15 +26,15 @@ class DialogMessage:


 def tokenize_dialogue(  # noqa: C901
-    dialogue: Union[str, List[str]], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
+    dialogue: Union[str, Iterable[str]], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
 ) -> List[DialogMessage]:
     """
     Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...)
     """
     if isinstance(dialogue, str):
         bos_token = tokenizer.bos_token or tokenizer.eos_token
         dialogue = [bos_token, dialogue]
-    elif isinstance(dialogue, tuple):
+    elif isinstance(dialogue, Iterable):
         if len(dialogue) % 2 != 0:
             raise ValueError("Dialogue must have an even number of phrases, alternating prompt and output")
         dialogue = list(dialogue)
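
One subtlety in the widened check: strings are themselves iterable, which is why the `str` branch must come first; with `Iterable`, plain lists (not only tuples) now reach the even-length validation. A small illustration, assuming `Iterable` here resolves to `collections.abc.Iterable`:

    from collections.abc import Iterable

    assert isinstance("a bare string", Iterable)        # so the str case is handled first
    assert isinstance(["prompt", "output"], Iterable)   # lists now pass
    assert isinstance(("prompt", "output"), Iterable)   # tuples still pass

Note that `len(dialogue)` is still called before `list(dialogue)`, so in practice the input must be a sized iterable; a one-shot generator would raise a `TypeError` at the length check.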
@@ -64,9 +64,17 @@ def tokenize_dialogue(  # noqa: C901
         truncated = [DialogMessage(is_output=m.is_output, tokens=m.tokens[::-1]) for m in truncated[::-1]]

     # remove empty messages
-    truncated = [t for t in truncated if len(t.tokens) > 0]
+    out = [t for t in truncated if len(t.tokens) > 0]

-    return truncated
+    if out[0].is_output:
+        if sum(map(lambda msg: len(msg.tokens), out)) == max_length:
+            if tokenizer.truncation_side == "left":
+                out[0].tokens = out[0].tokens[1:]
+            else:
+                out[-1].tokens = out[-1].tokens[:-1]
+
+        out.insert(0, DialogMessage(False, (tokenizer.bos_token_id,)))
+    return out


 class DialogStore(BaseRolloutStore):
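
The new tail logic, restated as a self-contained sketch (`ensure_leading_prompt` is a hypothetical name; `DialogMessage` is stubbed to match the diff): if truncation removed every prompt token, the dialog would otherwise start with an output message, so one token is dropped from the truncation side when the dialog is already at capacity, and a synthetic one-token `<bos>` prompt is prepended.

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class DialogMessage:
        is_output: bool
        tokens: Tuple[int, ...]

    def ensure_leading_prompt(
        out: List[DialogMessage], bos_token_id: int, max_length: int, truncation_side: str
    ) -> List[DialogMessage]:
        if out[0].is_output:
            # already at capacity: make room for <bos> by dropping one token
            # from whichever side truncation operates on
            if sum(len(msg.tokens) for msg in out) == max_length:
                if truncation_side == "left":
                    out[0].tokens = out[0].tokens[1:]
                else:
                    out[-1].tokens = out[-1].tokens[:-1]
            out.insert(0, DialogMessage(is_output=False, tokens=(bos_token_id,)))
        return out

    # a dialog at capacity whose prompt was truncated away entirely:
    msgs = [DialogMessage(is_output=True, tokens=(13, 50256))]
    fixed = ensure_leading_prompt(msgs, bos_token_id=50256, max_length=2, truncation_side="left")
    assert fixed == [DialogMessage(False, (50256,)), DialogMessage(True, (50256,))]

The length invariant is preserved: a dialog already at `max_length` stays at `max_length` while regaining a prompt/output boundary.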
