Skip to content

Using custom vocabulary with subword tokenizer #209

@aoezdTchibo

Description

@aoezdTchibo

The rationale for this problem stems from a related issue in the sentence-transformer project: UKPLab/sentence-transformers#3313

For testing purposes I created following training script (source @tomaarsen comment on UKPLab/sentence-transformers#3281):

Training script with a distilled model with custom vocabulary
import logging
import random

from nltk.corpus import words
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers

from datasets import Dataset, DatasetDict

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_dataset():
    try:
        return DatasetDict.load_from_disk("datasets/train_dataset")
    except FileNotFoundError:
        wordlist = words.words()

        dataset_size = 1_000_000
        dataset_names = ["dataset_A", "dataset_B", "dataset_C", "dataset_D", "dataset_E"]
        datasets = {}
        for dataset_name in dataset_names:
            # Create a dummy dataset with random sentences
            sentence_As = [" ".join(random.sample(wordlist, k=30)) for _ in range(dataset_size)]
            sentence_Bs = [" ".join(random.sample(wordlist, k=30)) for _ in range(dataset_size)]
            dataset = Dataset.from_dict({"sentence_A": sentence_As, "sentence_B": sentence_Bs})
            datasets[dataset_name] = dataset

        train_dataset = DatasetDict(datasets)
        train_dataset.save_to_disk("datasets/train_dataset")

        return train_dataset


def main():
    base_model_name = "intfloat/multilingual-e5-large"
    base_model = SentenceTransformer(base_model_name)
    new_tokens = ["my_custom_token"]
    custom_vocab = list(base_model.tokenizer.get_vocab()) + new_tokens
    static_embedding = StaticEmbedding.from_distillation(
        base_model_name,
        device="cpu",
        vocabulary=custom_vocab,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        device="cpu",
    )
    train_dataset = load_train_dataset()
    loss = MultipleNegativesRankingLoss(
        model,
        scale=50,
    )
    run_name = f"{base_model_name}_custom_vocab"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="no",
        save_strategy="no",
        logging_steps=50,
        logging_first_step=True,
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    print("train...")
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}")

if __name__ == "__main__":
    main()

If I run my script I get the following error (ValueError: Number of tokens (250002) does not match number of vectors (404627)):

2025-04-14 10:19:18 - Use pytorch device_name: mps
2025-04-14 10:19:18 - Load pretrained SentenceTransformer: intfloat/multilingual-e5-large
2025-04-14 10:19:25 - The `apply_zipf` parameter is deprecated and will be removed in the next release. Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, no weighting is applied.
2025-04-14 10:19:26 - Removed 95373 duplicate tokens.
2025-04-14 10:19:26 - Removed 4 multiword tokens.
2025-04-14 10:19:26 - Adding 154628 tokens to the vocabulary. Removed 95377 tokens during preprocessing.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 396/396 [02:50<00:00,  2.32it/s]
2025-04-14 10:22:26 - Applying PCA with n_components 256
2025-04-14 10:22:40 - Reduced dimensionality from 1024 to 256.
2025-04-14 10:22:40 - Explained variance ratio: 0.914.
2025-04-14 10:22:40 - Explained variance: 167.012.
2025-04-14 10:22:40 - Estimating word frequencies using Zipf's law, and then applying SIF.
Traceback (most recent call last):
  File "(...)/train.py", line 99, in <module>
    main()
  File "(...)/train.py", line 49, in main
    static_embedding = StaticEmbedding.from_distillation(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "(...)/.venv/lib/python3.11/site-packages/sentence_transformers/models/StaticEmbedding.py", line 175, in from_distillation
    static_model = distill(
                   ^^^^^^^^
  File "(...)/.venv/lib/python3.11/site-packages/model2vec/distill/distillation.py", line 241, in distill
    return distill_from_model(
           ^^^^^^^^^^^^^^^^^^^
  File "(...)/.venv/lib/python3.11/site-packages/model2vec/distill/distillation.py", line 136, in distill_from_model
    return StaticModel(
           ^^^^^^^^^^^^
  File "(...)/.venv/lib/python3.11/site-packages/model2vec/model.py", line 50, in __init__
    raise ValueError(f"Number of tokens ({len(tokens)}) does not match number of vectors ({vectors.shape[0]})")
ValueError: Number of tokens (250002) does not match number of vectors (404627)

What issues exist with the script, or how can I effectively expand the vocabulary to include my required custom tokens?

Given that the subword tokenizer feature hasn't been released yet, when can we expect its availability? This information is crucial as it impacts our own release schedule. 😁

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions