## Description

**Location:** `_decode_function`, called from `_apply_new_characters`
## What's happening
The code in `_apply_new_characters` computes

```python
prev_decoded = self.decoder(state.current_word_tokens)
new_decoded = self.decoder(new_state.current_word_tokens)
new_characters = new_decoded[len(prev_decoded):]
```

using the default `tokenizer.decode()` (which has `clean_up_tokenization_spaces=True`). That cleanup step collapses runs of spaces and normalizes punctuation spacing, so in practice:
```python
# Llama-3 tokenizer example
state.current_word_tokens[-5:]        # last 5 tokens
# [5766, 264, 702, 220, 220]
new_state.current_word_tokens[-6:]    # same 5 + new_token=1359
# [5766, 264, 702, 220, 220, 1359]

self.decoder(state.current_word_tokens[-5:])
# => ' avoid a"\n  '    # two trailing spaces preserved
self.decoder(new_state.current_word_tokens[-6:])
# => ' avoid a"\n ,"'   # cleanup collapsed the run to a single space!

new_characters = new_decoded[len(prev_decoded):]
# => '"'                # the ',' is lost!
```
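For reference, the collapse can be reproduced with a bare `transformers` tokenizer, outside the integration. A minimal sketch, assuming a Llama-3 checkpoint (the model name here is illustrative; the token ids are the ones from the example above) and a transformers version where cleanup still defaults to `True`:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

prev_tokens = [5766, 264, 702, 220, 220]
new_tokens = prev_tokens + [1359]

# Default decode applies clean_up_tokenization_spaces,
# which rewrites " ," to ","
prev_decoded = tok.decode(prev_tokens)
new_decoded = tok.decode(new_tokens)

print(repr(new_decoded[len(prev_decoded):]))
# => '"'   # the comma has vanished from the character diff
```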
As a result, the `','` is never added to the `new_state` parser, leaving the parser state wrong. The JSON string parser then allows an EOT token to be generated, and the output is truncated, invalid JSON.
Furthermore, the `.rstrip('�')` call induces the same behavior. I hit a case where a model genuinely generated this character: it was stripped during decoding and therefore never added as a new character, corrupting the parser state so that EOS was allowed on the next generation step. This, too, caused premature termination of the generation.
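The underlying ambiguity: U+FFFD (`'�'`) appears both as a placeholder while a multi-byte UTF-8 sequence is still incomplete and as a character the model can legitimately emit, and `rstrip` cannot tell the two apart. A made-up string to illustrate:

```python
# '�' here is a character the model actually generated,
# not a partial-decode artifact
decoded = 'some text �'
decoded.rstrip('�')
# => 'some text '   # the genuine character is silently dropped,
#                   # so the parser never sees it
```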
## Suggested fix
Disable the space cleanup in the integration’s decode function:
```diff
 def _decode_function(tokenizer, tokens):
-    decoded = tokenizer.decode(tokens)
-    return decoded.rstrip('�')
+    decoded = tokenizer.decode(
+        tokens,
+        clean_up_tokenization_spaces=False
+    )
+    return decoded
```
This ensures every space, comma, newline, and quote is counted, so closing punctuation never gets dropped and JSON fields finish correctly.
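As a sanity check (same assumed tokenizer and token ids as the repro sketch above), disabling the cleanup makes the character diff come out right:

```python
prev_decoded = tok.decode(prev_tokens, clean_up_tokenization_spaces=False)
new_decoded = tok.decode(new_tokens, clean_up_tokenization_spaces=False)

print(repr(new_decoded[len(prev_decoded):]))
# => ',"'   # both characters survive, so the parser state stays correct
```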