-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Open
Labels
Description
Description
When I added new words to the vocabulary (Arabic), the model's performance deteriorated significantly when I added UTF-8 encoded words. I checked vocab.json and found that all the words were Latin-1 encoded, so I added Latin-1 encoded words again. However, if I passed a string (UTF-8), the model wouldn't segment the words correctly. It only worked when I passed a Latin-1 encoded string, but the original model was using UTF-8 encoded strings.
Reproduction
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import torch
from typing import Set
model_path = "./Qwen3/models/Qwen3-4B-Instruct-2507"
save_path = "./Qwen3/models/Qwen3-4B-Instruct-2507-arabic-v2"
def utf8_to_latin1_bytes(w: str) -> str:
return w.encode("utf-8").decode("latin-1")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
existing_tokens = set(tokenizer.get_vocab().keys())
new_tokens = ["فقال", "جرة"]
end_add_tokens = []
utf_8_add_tokens = []
for word in new_tokens:
latin_form = utf8_to_latin1_bytes(word)
if latin_form not in existing_tokens:
end_add_tokens.append(latin_form)
utf_8_add_tokens.append(word)
else:
print(f"跳过已存在的词汇: [UTF-8] {word} [Latin-1] {latin_form}")
initial_vocab_size = len(tokenizer)
print('初始token数量:', initial_vocab_size) # 151669
num_added_tokens = tokenizer.add_tokens(end_add_tokens)
print("新增token数量:", num_added_tokens) # 3509
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
embedding = model.get_input_embeddings() # 原始 embedding
old_vocab_size = embedding.weight.shape[0] # 原始模型的 embedding 行数
print("old_vocab_size:", old_vocab_size) # 151936
model.resize_token_embeddings(len(tokenizer))
embedding = model.get_input_embeddings()
new_vocab_size = embedding.weight.shape[0]
print("new_vocab_size:", new_vocab_size) # 155178
print("新增可训练行数:", new_vocab_size-old_vocab_size) # 3242
print(f'最终需要训练的embedding维度: [{initial_vocab_size}:, :]') # [151669:, :]
merges_rules = []
print("UTF-8 格式编码 分词结果:")
for text in utf_8_add_tokens:
input_ids = tokenizer.encode(text, add_special_tokens=False)
# print(f"input_ids: {input_ids}")
decoded1_list = [tokenizer.decode(input_id, skip_special_tokens=True) for input_id in input_ids]
if len(decoded1_list) == 2:
merges_rules.append(decoded1_list)
else:
print(f"新增合并规则失败: {decoded1_list}")
print(f"新增合并规则: {merges_rules}")
merges_rules = []
print("Latin-1 格式编码 分词结果:")
for text in end_add_tokens:
input_ids = tokenizer.encode(text, add_special_tokens=False)
# print(f"input_ids: {input_ids}")
decoded1_list = [tokenizer.decode(input_id, skip_special_tokens=True) for input_id in input_ids]
if len(decoded1_list) == 2:
merges_rules.append(decoded1_list)
else:
print(f"新增合并规则失败: {decoded1_list}")
print(f"新增合并规则: {merges_rules}")Logs
新增token数量: 2
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 15.96it/s]
old_vocab_size: 151936
new_vocab_size: 151671
新增可训练行数: -265
最终需要训练的embedding维度: [151669:, :]
UTF-8 格式编码 分词结果:
新增合并规则: [['ف', 'قال'], ['جر', 'ة']]
Latin-1 格式编码 分词结果:
新增合并规则失败: ['Ù\x81Ù\x82اÙ\x84']
新增合并规则失败: ['جرة']
新增合并规则: []Environment Information
NVIDIA A800-SXM4-80GB NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.4
Package Version
---------------------------------------- -------------
absl-py 2.3.1
accelerate 1.10.1
addict 2.4.0
aiofiles 24.1.0
aiohappyeyeballs 2.6.1
aiohttp 3.12.15
aiosignal 1.4.0
airportsdata 20250811
aliyun-python-sdk-core 2.16.0
aliyun-python-sdk-kms 2.16.5
annotated-types 0.7.0
antlr4-python3-runtime 4.13.2
anyio 4.10.0
astor 0.8.1
attrdict 2.0.1
attrs 25.3.0
binpacking 1.5.2
blake3 1.0.5
Brotli 1.1.0
cachetools 6.1.0
camel-kenlm 2025.9.16
certifi 2025.8.3
cffi 2.0.0
charset-normalizer 3.4.3
click 8.2.1
cloudpickle 3.1.1
compressed-tensors 0.9.3
contourpy 1.3.3
cpm-kernels 1.0.11
crcmod 1.7
cryptography 46.0.1
cupy-cuda12x 13.6.0
cycler 0.12.1
dacite 1.9.2
datasets 3.6.0
deepspeed 0.17.5
Deprecated 1.2.18
depyf 0.18.0
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
dnspython 2.7.0
docopt 0.6.2
editdistance 0.8.1
einops 0.8.1
email_validator 2.2.0
emoji 2.15.0
fastapi 0.116.1
fastapi-cli 0.0.8
fastapi-cloud-cli 0.1.5
fastrlock 0.8.3
ffmpy 0.6.1
filelock 3.19.1
fonttools 4.59.2
frozenlist 1.7.0
fsspec 2025.3.0
future 1.0.0
gguf 0.17.1
googleapis-common-protos 1.70.0
gradio 5.46.0
gradio_client 1.13.0
groovy 0.1.2
grpcio 1.74.0
h11 0.16.0
hf-xet 1.1.8
hjson 3.1.0
httpcore 1.0.9
httptools 0.6.4
httpx 0.28.1
huggingface-hub 0.34.4
idna 3.10
importlib_metadata 8.0.0
interegular 0.3.3
jedi 0.19.2
jieba 0.42.1
Jinja2 3.1.6
jiter 0.10.0
jmespath 0.10.0
joblib 1.5.2
jsonschema 4.25.1
jsonschema-specifications 2025.4.1
kiwisolver 1.4.9
lark 1.2.2
latex2sympy2_extended 1.0.6
llguidance 0.7.30
llvmlite 0.44.0
lm-format-enforcer 0.10.12
Markdown 3.9
markdown-it-py 4.0.0
MarkupSafe 3.0.2
math-verify 0.5.2
matplotlib 3.10.6
mdurl 0.1.2
mistral_common 1.8.4
modelscope 1.29.1
mpmath 1.3.0
ms_swift 3.8.1
msgpack 1.1.1
msgspec 0.19.0
muddler 0.1.3
multidict 6.6.4
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.5
ninja 1.13.0
nltk 3.9.1
numba 0.61.2
numpy 2.2.6
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cusparselt-cu12 0.6.2
nvidia-ml-py 13.580.82
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
openai 1.101.0
opencv-python-headless 4.12.0.88
opentelemetry-api 1.26.0
opentelemetry-exporter-otlp 1.26.0
opentelemetry-exporter-otlp-proto-common 1.26.0
opentelemetry-exporter-otlp-proto-grpc 1.26.0
opentelemetry-exporter-otlp-proto-http 1.26.0
opentelemetry-proto 1.26.0
opentelemetry-sdk 1.26.0
opentelemetry-semantic-conventions 0.47b0
opentelemetry-semantic-conventions-ai 0.4.13
orjson 3.11.3
oss2 2.19.1
outlines 0.1.11
outlines_core 0.1.26
packaging 25.0
pandas 2.3.2
parso 0.8.5
partial-json-parser 0.2.1.1.post6
peft 0.17.1
pillow 11.3.0
pip 25.1
prometheus_client 0.22.1
prometheus-fastapi-instrumentator 7.1.0
propcache 0.3.2
protobuf 4.25.8
psutil 7.0.0
pudb 2025.1.1
py-cpuinfo 9.0.0
pyarrow 21.0.0
pycountry 24.6.1
pycparser 2.23
pycryptodome 3.23.0
pydantic 2.11.7
pydantic_core 2.33.2
pydantic-extra-types 2.10.5
pydub 0.25.1
Pygments 2.19.2
pyparsing 3.2.4
pyrsistent 0.20.0
python-dateutil 2.9.0.post0
python-dotenv 1.1.1
python-json-logger 3.3.0
python-multipart 0.0.20
pytz 2025.2
PyYAML 6.0.2
pyzmq 27.0.2
ray 2.48.0
referencing 0.36.2
regex 2025.7.34
requests 2.32.5
rich 14.1.0
rich-toolkit 0.15.0
rignore 0.6.4
rouge 1.0.1
rpds-py 0.27.0
ruff 0.13.0
safehttpx 0.1.6
safetensors 0.6.2
scikit-learn 1.7.2
scipy 1.16.1
semantic-version 2.10.0
sentencepiece 0.2.1
sentry-sdk 2.35.0
setuptools 78.1.1
shellingham 1.5.4
simplejson 3.20.1
six 1.17.0
sniffio 1.3.1
sortedcontainers 2.4.0
starlette 0.47.3
sympy 1.13.1
tabulate 0.9.0
tensorboard 2.20.0
tensorboard-data-server 0.7.2
threadpoolctl 3.6.0
tiktoken 0.11.0
tokenizers 0.21.4
tomlkit 0.13.3
torch 2.6.0
torchaudio 2.6.0
torchvision 0.21.0
tqdm 4.67.1
transformers 4.55.4
transformers-stream-generator 0.0.5
triton 3.2.0
trl 0.20.0
typer 0.16.1
typing_extensions 4.14.1
typing-inspection 0.4.1
tzdata 2025.2
urllib3 2.5.0
urwid 3.0.3
urwid_readline 0.15.1
uvicorn 0.35.0
uvloop 0.21.0
vllm 0.8.5.post1
watchfiles 1.1.0
wcwidth 0.2.14
websockets 15.0.1
Werkzeug 3.1.3
wheel 0.45.1
wrapt 1.17.3
xformers 0.0.29.post2
xgrammar 0.1.18
xxhash 3.5.0
yarl 1.20.1
zipp 3.23.0
zstandard 0.25.0
Known Issue
- The issue hasn't been already addressed in Documentation, Issues, and Discussions.