@@ -3,7 +3,7 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, PreTrainedTokenizer
 
 from trlx.data.ilql_types import ILQLBatch, ILQLElement
 from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline
@@ -23,7 +23,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
     ctx_length = max_length
     if tokenizer.truncation_side == "left":
         for phrase in reversed(dialogue):
-            tokens = tokenizer(phrase).input_ids[-ctx_length:]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:]
             ctx_length -= len(tokens)
             out.insert(0, tokens)
             if ctx_length == 0:
@@ -38,7 +39,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
 
     elif tokenizer.truncation_side == "right":
         for phrase in dialogue:
-            tokens = tokenizer(phrase).input_ids[:ctx_length]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length]
             ctx_length -= len(tokens)
             out.append(tokens)
             if ctx_length == 0:
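Note for context: Hugging Face tokenizers insert model-specific special tokens (e.g. a BOS token) on every call by default, so without `add_special_tokens=False` each phrase of the dialogue could pick up an extra BOS/EOS on top of the ones `tokenize_dialogue` adds manually. A minimal sketch of the flag's effect; the checkpoint name is illustrative, and any tokenizer that adds special tokens by default behaves the same way:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative checkpoint

# Default call: the tokenizer prepends its BOS token to the ids.
with_specials = tokenizer("hello world").input_ids

# With the flag from this diff: only the tokens of the text itself.
plain = tokenizer("hello world", add_special_tokens=False).input_ids

assert len(plain) < len(with_specials)  # holds whenever the tokenizer adds special tokens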
@@ -52,13 +54,20 @@ class PromptPipeline(BasePipeline):
     Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
     """
 
-    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None):
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
         super().__init__()
-        model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)
-        prompts = model_inputs["input_ids"]
+
+        model_inputs = tokenizer(
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+        )
+
+        prompts_tokens = model_inputs["input_ids"]
         attention_mask = model_inputs["attention_mask"]
+
         self.tokenizer = tokenizer
-        self.prompts = [{"input_ids": prompt, "attention_mask": mask} for prompt, mask in zip(prompts, attention_mask)]
+        self.prompts = [
+            {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
+        ]
 
     def __getitem__(self, ix: int):
         return self.prompts[ix]
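With this change `tokenizer` is a required `PreTrainedTokenizer` rather than defaulting to `None`, and prompts are encoded without special tokens. A quick usage sketch, assuming the file under review is trlx/pipeline/offline_pipeline.py; the checkpoint and prompts are illustrative:

from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import PromptPipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
pipeline = PromptPipeline(["Hello there.", "How are you?"], max_prompt_length=32, tokenizer=tokenizer)

# Each element is a dict ready for a collator: {"input_ids": [...], "attention_mask": [...]}
print(pipeline[0])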