
Commit 29c2e26

Better tokenizing code for AuraFlow.
1 parent b6f09cf commit 29c2e26

5 files changed: 25 additions, 1175 deletions

comfy/text_encoders/aura_t5.py (3 additions, 3 deletions)

@@ -1,5 +1,5 @@
 from comfy import sd1_clip
-from transformers import LlamaTokenizerFast
+from .llama_tokenizer import LLAMATokenizer
 import comfy.t5
 import os

@@ -10,8 +10,8 @@ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):

 class PT5XlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LlamaTokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+        tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)

 class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None):
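
For context, a minimal sketch of what the two changed lines do (the names below are illustrative, not from the commit): the tokenizer path used to point at the t5_pile_tokenizer directory, which transformers' LlamaTokenizerFast would load as a pretrained tokenizer folder; it now points at the raw SentencePiece model file inside that directory, which the new in-tree LLAMATokenizer wraps directly.

import os

# Illustrative stand-in for os.path.dirname(os.path.realpath(__file__))
# as evaluated inside comfy/text_encoders/aura_t5.py.
base = "comfy/text_encoders"

old_path = os.path.join(base, "t5_pile_tokenizer")              # old: the tokenizer directory
new_path = os.path.join(os.path.join(base, "t5_pile_tokenizer"),
                        "tokenizer.model")                      # new: the SentencePiece model file

# The nested join in the diff is equivalent to a single three-argument join:
assert new_path == os.path.join(base, "t5_pile_tokenizer", "tokenizer.model")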
comfy/text_encoders/llama_tokenizer.py (new file, 22 additions, 0 deletions)

@@ -0,0 +1,22 @@
+import os
+
+class LLAMATokenizer:
+    @staticmethod
+    def from_pretrained(path):
+        return LLAMATokenizer(path)
+
+    def __init__(self, tokenizer_path):
+        import sentencepiece
+        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
+        self.end = self.tokenizer.eos_id()
+
+    def get_vocab(self):
+        out = {}
+        for i in range(self.tokenizer.get_piece_size()):
+            out[self.tokenizer.id_to_piece(i)] = i
+        return out
+
+    def __call__(self, string):
+        out = self.tokenizer.encode(string)
+        out += [self.end]
+        return {"input_ids": out}

comfy/text_encoders/t5_pile_tokenizer/added_tokens.json (file deleted, 102 deletions)

comfy/text_encoders/t5_pile_tokenizer/special_tokens_map.json (file deleted, 125 deletions)

0 commit comments