`model2vec/distill/distillation.py` (7 changes: 7 additions & 0 deletions)
```diff
@@ -185,6 +185,10 @@ def _remove_tokens_and_embeddings(
 
     # Remove the unused tokens from the tokenizer.
     new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
+    if new_tokenizer.get_vocab_size() == tokenizer.backend_tokenizer.get_vocab_size():
+        # This happens if we didn't remove any tokens.
+        return new_tokenizer, embeddings
+
     # Remove the embeddings of the unused tokens.
     embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
     logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
```
```diff
@@ -199,6 +203,7 @@ def distill(
     pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
+    token_remove_pattern: str | None = r"\[unused\d+\]",
 ) -> StaticModel:
     """
     Distill a StaticModel from a sentence transformer.
@@ -217,6 +222,7 @@ def distill(
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
+    :param token_remove_pattern: If this is set to a string, we compile it into a regex. Any tokens that match this pattern will be removed from the vocabulary.
     :return: A StaticModel
 
     """
```
```diff
@@ -231,6 +237,7 @@
         pca_dims=pca_dims,
         apply_zipf=apply_zipf,
         use_subword=use_subword,
+        token_remove_pattern=token_remove_pattern,
     )
 
 
```
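A hedged usage sketch of the new parameter; the model name below is only an example, and the parts of the `distill` signature not shown in this diff (such as `model_name`) are assumed from the surrounding code:

```python
from model2vec.distill import distill

# Default: tokens matching r"\[unused\d+\]" are dropped from the vocabulary,
# and their rows are deleted from the embedding matrix.
model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)

# Pass None to skip regex-based removal and keep the full vocabulary.
model_full = distill(
    model_name="BAAI/bge-base-en-v1.5",
    token_remove_pattern=None,
)
```

Because the default is a non-None pattern, existing callers gain the removal behaviour automatically; passing `token_remove_pattern=None` opts out.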