
Commit b91da7b

Set add_special_tokens=False to not add EOS unexpectedly (#287)
Do not `add_special_tokens` when preprocessing prompts. This changes performance (reward) for the better for non-GPT models, whose tokenizers append EOS by default (GPT tokenizers do not), because the prompt no longer ends with `<endoftext>`. Adding `<endofcontext>` (EOC, from HH) could be something to add in the future.
1 parent 81e935a commit b91da7b
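
For context, a minimal sketch (not part of the commit) of the tokenizer behaviour being avoided, assuming a T5 tokenizer as in the CNN/DailyMail example; the checkpoint name and prompt text are only illustrative:

    from transformers import AutoTokenizer

    # T5's tokenizer appends </s> (EOS) by default; GPT-2-style tokenizers do not.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    prompt = "Summarize: The cat sat on the mat."

    default_ids = tokenizer(prompt)["input_ids"]
    bare_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

    # The default encoding ends with the EOS id, so generation would start from a
    # "finished" prompt; add_special_tokens=False leaves the prompt open-ended.
    print(default_ids[-1] == tokenizer.eos_token_id)  # True
    print(bare_ids[-1] == tokenizer.eos_token_id)     # False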

6 files changed: +23 -64 lines changed

examples/summarize_daily_cnn/t5_summarize_daily_cnn.py

Lines changed: 2 additions & 2 deletions
@@ -51,14 +51,14 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
 
     for i in tqdm(range(len(prompts))):
         key = tokenizer.decode(
-            tokenizer(prompts[i], truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         )  # get prompt like trlx's prompt
         prompt_label[key.strip()] = summaries[i]
 
     for i in tqdm(range(len(val_prompts))):
         key = tokenizer.decode(
-            tokenizer(val_prompts[i], truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         )  # get prompt like trlx's prompt
         prompt_label[key.strip()] = val_summaries[i]
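
A side note on why the flag matters for the lookup keys above (a sketch under assumed values; `max_length` here is an arbitrary small budget and "t5-small" a stand-in checkpoint): with the default add_special_tokens=True, the appended EOS consumes one of the max_length slots during truncation, so the decoded key can differ from the prompt trlx later builds without special tokens.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in checkpoint
    text = "The quick brown fox jumps over the lazy dog near the river"
    max_length = 8  # hypothetical small budget, just to make the effect visible

    old_ids = tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
    new_ids = tokenizer(text, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"]

    # The old call keeps 7 content tokens plus EOS; the new call keeps 8 content tokens,
    # so the two decoded keys are not identical even with skip_special_tokens=True.
    print(tokenizer.decode(old_ids, skip_special_tokens=True))
    print(tokenizer.decode(new_ids, skip_special_tokens=True))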

examples/summarize_rlhf/trlx_gptj_text_summarization.py

Lines changed: 2 additions & 1 deletion
@@ -67,12 +67,13 @@ def get_prompt_dataset(prompts, max_length):
                 prompts[i].split("TL;DR:")[0],
                 truncation=True,
                 max_length=max_length - 5,  # to make sure "TL;DR" dont get truncated
+                add_special_tokens=False,
             )["input_ids"],
             skip_special_tokens=True,
         ).strip()
         tmp = tmp + "\nTL;DR:"
         tmp = tokenizer.decode(
-            tokenizer(tmp, truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         ).strip()
         formatted_prompts.append(tmp)
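
For this GPT-J example the flag should be a no-op, since GPT-2-style tokenizers (GPT-J reuses the GPT-2 vocabulary) add no special tokens by default; a quick sanity check, assuming the plain "gpt2" tokenizer as a stand-in:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for the GPT-J tokenizer
    text = "Post body goes here.\nTL;DR:"

    # Identical output: the explicit flag documents intent and protects checkpoints
    # whose tokenizers do insert BOS/EOS automatically.
    assert tokenizer(text)["input_ids"] == tokenizer(text, add_special_tokens=False)["input_ids"]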

trlx/pipeline/offline_pipeline.py

Lines changed: 16 additions & 7 deletions
@@ -3,7 +3,7 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, PreTrainedTokenizer
 
 from trlx.data.ilql_types import ILQLBatch, ILQLElement
 from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline
@@ -23,7 +23,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
     ctx_length = max_length
     if tokenizer.truncation_side == "left":
         for phrase in reversed(dialogue):
-            tokens = tokenizer(phrase).input_ids[-ctx_length:]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:]
             ctx_length -= len(tokens)
             out.insert(0, tokens)
             if ctx_length == 0:
@@ -38,7 +39,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
 
     elif tokenizer.truncation_side == "right":
         for phrase in dialogue:
-            tokens = tokenizer(phrase).input_ids[:ctx_length]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length]
             ctx_length -= len(tokens)
             out.append(tokens)
             if ctx_length == 0:
@@ -52,13 +54,20 @@ class PromptPipeline(BasePipeline):
     Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
     """
 
-    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None):
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
         super().__init__()
-        model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)
-        prompts = model_inputs["input_ids"]
+
+        model_inputs = tokenizer(
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+        )
+
+        prompts_tokens = model_inputs["input_ids"]
         attention_mask = model_inputs["attention_mask"]
+
         self.tokenizer = tokenizer
-        self.prompts = [{"input_ids": prompt, "attention_mask": mask} for prompt, mask in zip(prompts, attention_mask)]
+        self.prompts = [
+            {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
+        ]
 
     def __getitem__(self, ix: int):
         return self.prompts[ix]
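
A minimal usage sketch of the updated `PromptPipeline` (the tokenizer checkpoint, prompts, and length budget below are made up for illustration):

    from transformers import AutoTokenizer
    from trlx.pipeline.offline_pipeline import PromptPipeline

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
    prompts = ["Once upon a time", "The meaning of life is"]  # made-up prompts

    pipeline = PromptPipeline(prompts, max_prompt_length=16, tokenizer=tokenizer)

    # Each element is a dict with unpadded input_ids and attention_mask; no BOS/EOS
    # is inserted, because the pipeline now tokenizes with add_special_tokens=False.
    print(pipeline[0])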

trlx/trainer/accelerate_base_trainer.py

Lines changed: 1 addition & 20 deletions
@@ -3,7 +3,7 @@
 import sys
 from abc import abstractmethod
 from time import time
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Tuple
 
 import ray
 import torch
@@ -175,25 +175,6 @@ def setup_scheduler(self):
         scheduler = scheduler_class(self.opt, **self.config.scheduler.kwargs)
         return scheduler
 
-    def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        """
-        Tokenize a batch of text after adding bos token to each of the samples
-        """
-        if isinstance(text[0], torch.LongTensor):
-            return text
-
-        text = [self.tokenizer.bos_token + txt for txt in text]
-        return self.tokenizer(
-            text,
-            truncation=True,
-            max_length=self.config.seq_length,
-            return_tensors="pt",
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-
     def decode(
         self,
         prompts: List[torch.LongTensor],
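
The NOTE in the removed helper refers to tokenizers such as OPT's, which prepend BOS on their own; a sketch of the duplication it guarded against (assuming the facebook/opt-125m tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # assumed checkpoint
    text = tokenizer.bos_token + "hello world"  # BOS added manually, as the removed helper did

    naive = tokenizer(text)["input_ids"]                              # tokenizer prepends BOS again
    guarded = tokenizer(text, add_special_tokens=False)["input_ids"]  # only the manual BOS remains

    print(naive[:2] == [tokenizer.bos_token_id] * 2)  # True: duplicated BOS
    print(guarded[0] == tokenizer.bos_token_id)       # True: single BOS, as intended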

trlx/trainer/accelerate_ilql_trainer.py

Lines changed: 1 addition & 17 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Sequence, Union, cast
+from typing import Optional, cast
 
 import numpy as np
 import torch
@@ -43,22 +43,6 @@ def get_arch(self, config):
             num_layers_unfrozen=config.model.num_layers_unfrozen,
         )
 
-    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        if isinstance(texts[0], torch.LongTensor):
-            return texts
-
-        tokenized = self.tokenizer(
-            [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
-            max_length=self.max_length,
-            truncation=True,
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-        input_ids = list(map(torch.as_tensor, tokenized.input_ids))
-        return input_ids
-
     def post_backward_callback(self):
         if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
             self.accelerator.unwrap_model(self.model).sync_target_q_heads()

trlx/trainer/nemo_ilql_trainer.py

Lines changed: 1 addition & 17 deletions
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Iterable, Sequence, Union, cast
+from typing import Iterable, Sequence, cast
 
 import numpy as np
 import torch
@@ -156,22 +156,6 @@ def __init__(
         if stop_sequences is not None and len(stop_sequences) > 0:
             logging.warning(f"Ignoring stop_sequences {stop_sequences=}")
 
-    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        if isinstance(texts[0], torch.LongTensor):
-            return texts
-
-        tokenized = self.tokenizer(
-            [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
-            max_length=self.max_length,
-            truncation=True,
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-        input_ids = list(map(torch.as_tensor, tokenized.input_ids))
-        return input_ids
-
     def learn(self):
         def collate_fn(elems: Iterable[ILQLElement]):
             batch = ilql_collate_fn(elems)
