@@ -11,6 +11,19 @@ class TestTokenizeDialog(TestCase):
11
11
def setUp(self):
    """Create a fresh GPT-2 tokenizer before each test.

    NOTE(review): `from_pretrained` presumably pulls from the local HF cache
    (or the network on first run) — tests depend on that being available.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    self.tokenizer = tokenizer
14
def test_tokenize_dialogue_truncation_basic(self):
    """Left-truncation to max_length=2 should leave exactly one special
    token per side: BOS for the user turn and EOS for the bot turn."""
    self.tokenizer.truncation_side = "left"
    turns = ["this will be truncated", "."]

    result = tokenize_dialogue(turns, self.tokenizer, max_length=2)

    # Both dialog messages survive truncation...
    assert len(result) == 2
    prompt_msg, reply_msg = result

    # ...but each is squeezed down to a single token.
    assert len(prompt_msg.tokens) == 1
    assert len(reply_msg.tokens) == 1
    assert prompt_msg == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
    assert reply_msg == DialogMessage(is_output=True, tokens=(self.tokenizer.eos_token_id,))
14
27
@given (st .lists (st .text (), max_size = 32 ))
15
28
def test_tokenize_dialogue_single_turn (self , response_words ):
16
29
response = " " .join (response_words ) # space seperate to make it multiple tokens
@@ -46,20 +59,18 @@ def test_tokenize_dialogue_single_turn_truncation_left(self, response_words, max
46
59
response = " " .join (response_words ) # space seperate to make it multiple tokens
47
60
self .tokenizer .truncation_side = "left"
48
61
tokenized_response = tuple (self .tokenizer (response , add_special_tokens = False ).input_ids )
49
- tokenized_response = tokenized_response + (self .tokenizer .eos_token_id ,)
62
+ tokenized_response += (self .tokenizer .eos_token_id ,)
50
63
dialog = tokenize_dialogue (response , self .tokenizer , max_length = max_length )
51
64
52
- # if no truncation should have happened, then the user BOS prompt should be present
53
- if len (tokenized_response ) + 1 <= max_length :
54
- assert len ( dialog ) == 2
55
- user_dm , bot_dm = dialog
65
+ # whether or not truncation has happened, user BOS prompt should be present
66
+ assert len (dialog ) == 2
67
+ user_dm , bot_dm = dialog
68
+ assert user_dm == DialogMessage ( is_output = False , tokens = ( self . tokenizer . bos_token_id ,))
56
69
57
- assert user_dm == DialogMessage ( is_output = False , tokens = ( self . tokenizer . bos_token_id ,))
70
+ if len ( tokenized_response ) < max_length :
58
71
assert bot_dm == DialogMessage (is_output = True , tokens = tokenized_response )
59
72
else :
60
- assert len (dialog ) == 1
61
- bot_dm = dialog [0 ]
62
- assert bot_dm == DialogMessage (is_output = True , tokens = tokenized_response [- max_length :])
73
+ assert bot_dm == DialogMessage (is_output = True , tokens = tokenized_response [- max_length + 1 :])
63
74
64
75
all_tokens = sum ((dm .tokens for dm in dialog ), ())
65
76
assert len (all_tokens ) <= max_length
@@ -76,6 +87,9 @@ def test_tokenize_dialogue_multi_turn(self, user_response_pairs):
76
87
77
88
dm_convo = [DialogMessage (is_output = i % 2 == 1 , tokens = tokens ) for i , tokens in enumerate (tokenized_flat_convo )]
78
89
nonempty_dm_convo = [dm for dm in dm_convo if dm .tokens ]
90
+ if nonempty_dm_convo [0 ].is_output :
91
+ nonempty_dm_convo .insert (0 , DialogMessage (is_output = False , tokens = (self .tokenizer .eos_token_id ,)))
92
+
79
93
assert dialog == nonempty_dm_convo
80
94
81
95
@given (st .lists (st .tuples (st .text (), st .text ()), min_size = 1 , max_size = 32 ), st .integers (min_value = 2 , max_value = 16 ))
@@ -91,6 +105,9 @@ def test_tokenize_dialogue_multi_turn_truncation_right(self, user_response_pairs
91
105
92
106
all_tokens = sum ((dm .tokens for dm in dialog ), ())
93
107
should_be_tokens = sum (tokenized_flat_convo , ())[:max_length ]
108
+ if dialog [0 ] == DialogMessage (is_output = False , tokens = (self .tokenizer .eos_token_id ,)):
109
+ should_be_tokens = (self .tokenizer .eos_token_id , * should_be_tokens [: max_length - 1 ])
110
+
94
111
assert all_tokens == should_be_tokens
95
112
assert len (all_tokens ) <= max_length
96
113
@@ -106,8 +123,9 @@ def test_tokenize_dialogue_multi_turn_truncation_left(self, user_response_pairs,
106
123
dialog = tokenize_dialogue (flat_convo , self .tokenizer , max_length = max_length )
107
124
108
125
all_tokens = sum ((dm .tokens for dm in dialog ), ())
109
-
110
126
should_be_tokens = sum (tokenized_flat_convo , ())[- max_length :]
111
- assert all_tokens == should_be_tokens
127
+ if dialog [0 ] == DialogMessage (is_output = False , tokens = (self .tokenizer .eos_token_id ,)):
128
+ should_be_tokens = (self .tokenizer .eos_token_id , * should_be_tokens [- max_length + 1 :])
112
129
130
+ assert all_tokens == should_be_tokens
113
131
assert len (all_tokens ) <= max_length
0 commit comments