Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@
<img src="assets/images/model2vec_model_diagram_transparant_light.png#gh-light-mode-only" width="90%">
</div>

Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by 15x and making the models up to 500x faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.
Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by 15x and making the models up to 500x faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-32M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.


## Updates & Announcements

- **30/01/2024**: We released two new models: [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) and [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M). [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is our most performant model to date, using a larger vocabulary and higher dimensions. [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) is a finetune of [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) that is optimized for retrieval tasks, and is the best performing static retrieval model currently available.
- **30/10/2024**: We released three new models: [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M), [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M), and [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M). These models are trained using [Tokenlearn](https://github.com/MinishLab/tokenlearn). Find out more in our [blog post](https://minishlab.github.io/tokenlearn_blogpost/). NOTE: for users of any of our old English M2V models, we recommend switching to these new models as they [perform better on all tasks](https://github.com/MinishLab/model2vec/tree/main/results).

## Table of Contents
Expand Down Expand Up @@ -491,9 +492,11 @@ We provide a number of models that can be used out of the box. These models are

| Model | Language | Vocab | Sentence Transformer | Tokenizer Type | Params | Tokenlearn |
|-----------------------------------------------------------------------|-------------|------------------|-----------------------------------------------------------------|----------------|---------|-------------------|
| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | English | Output + Frequent C4 tokens | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 32.3M | <div align="center">✅</div> |
| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 7.5M | <div align="center">✅</div> |
| [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 3.7M | <div align="center">✅</div> |
| [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 1.8M | <div align="center">✅</div> |
| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | English | Output + Frequent C4 tokens | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 32.3M | <div align="center">✅</div> |
| [M2V_multilingual_output](https://huggingface.co/minishlab/M2V_multilingual_output) | Multilingual | Output | [LaBSE](https://huggingface.co/sentence-transformers/LaBSE) | Subword | 471M | <div align="center">❌</div> |


Expand Down
Binary file added assets/images/speed_vs_mteb_score_v3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 30 additions & 9 deletions results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ Note: The `potion` and `M2V` models are our static models.
| Model | Avg (All) | Avg (MTEB) | Class | Clust | PairClass | Rank | Ret | STS | Sum | Pearl | WordSim |
|:-----------------------|------------:|-------------:|--------:|--------:|------------:|-------:|-------:|-------:|-------:|--------:|----------:|
| [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 56.08 | 56.09 | 62.62 | 41.94 | 82.37 | 58.04 | 41.95 | 78.90 | 30.81 | 60.83 | 49.91 |
| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 50.54 | 50.03 | 64.44 | 32.93 | 76.62 | 49.73 | 31.71 | 73.24 | 29.28 | 53.54 | 50.75 |
| [M2V_base_glove_subword](https://huggingface.co/minishlab/M2V_base_glove_subword) | 49.06 | 46.69 | 61.27 | 30.03 | 74.71 | 49.15 | 27.16 | 69.09 | 30.08 | 56.82 | 57.99 |
| [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | 48.87 | 48.23 | 62.19 | 31.47 | 75.37 | 48.75 | 29.11 | 72.19 | 28.89 | 52.55 | 49.21 |
| [M2V_base_glove](https://huggingface.co/minishlab/M2V_base_glove) | 48.58 | 47.6 | 61.35 | 30.52 | 75.34 | 48.5 | 29.26 | 70.31 | 31.5 | 50.28 | 54.29 |
| [M2V_base_output](https://huggingface.co/minishlab/M2V_base_output) | 46.79 | 45.34 | 61.25 | 25.58 | 74.9 | 47.63 | 26.14 | 68.58 | 29.2 | 54.02 | 49.18 |
| [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | 45.52 | 44.77 | 58.45 | 27.5 | 73.72 | 46.82 | 24.13 | 70.14 | 31.51 | 50.82 | 44.72 |
| [GloVe_300d](https://huggingface.co/sentence-transformers/average_word_embeddings_glove.6B.300d) | 42.84 | 42.36 | 57.31 | 27.66 | 72.48 | 43.3 | 22.78 | 61.9 | 28.81 | 45.65 | 43.05 |
| [BPEmb_50k_300d](https://github.com/bheinzerling/bpemb) | 39.34 | 37.78 | 55.76 | 23.35 | 57.86 | 43.21 | 17.5 | 55.1 | 29.74 | 47.56 | 41.28 |
| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | 52.46 | 51.66 | 65.97 | 35.29 | 78.17 | 50.92 | 33.52 | 74.22 | 29.78 | 55.37 | 55.15 |
| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 50.54 | 50.03 | 64.44 | 32.93 | 76.62 | 49.73 | 31.71 | 73.24 | 29.28 | 53.54 | 50.75 |
| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | 49.73 | 49.76 | 59.56 | 30.55 | 76.38 | 50.05 | 36.35 | 73.22 | 28.85 | 49.31 | 50.02 |
| [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | 48.87 | 48.23 | 62.19 | 31.47 | 75.37 | 48.75 | 29.11 | 72.19 | 28.89 | 52.55 | 49.21 |
| [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1) | 48.18 | 48.36 | 57.39 | 28.32 | 75.63 | 49.16 | 35.61 | 72.18 | 28.64 | 49.68 | 44.76 |
| [static-similarity-mrl-multilingual-v1](https://huggingface.co/minishlab/static-similarity-mrl-multilingual-v1) | 48.15 | 47.15 | 59.96 | 24.40 | 79.02 | 48.25 | 29.54 | 74.88 | 30.28 | 51.66 | 51.66 |
| [M2V_base_output](https://huggingface.co/minishlab/M2V_base_output) | 46.79 | 45.34 | 61.25 | 25.58 | 74.9 | 47.63 | 26.14 | 68.58 | 29.2 | 54.02 | 49.18 |
| [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | 45.52 | 44.77 | 58.45 | 27.5 | 73.72 | 46.82 | 24.13 | 70.14 | 31.51 | 50.82 | 44.72 |
| [GloVe_300d](https://huggingface.co/sentence-transformers/average_word_embeddings_glove.6B.300d) | 42.84 | 42.36 | 57.31 | 27.66 | 72.48 | 43.3 | 22.78 | 61.9 | 28.81 | 45.65 | 43.05 |
| [BPEmb_50k_300d](https://github.com/bheinzerling/bpemb) | 39.34 | 37.78 | 55.76 | 23.35 | 57.86 | 43.21 | 17.5 | 55.1 | 29.74 | 47.56 | 41.28 |


<details>
Expand All @@ -36,14 +38,33 @@ For readability, the MTEB task names are abbreviated as follows:
- Sum: Summarization
</details>

The results show that [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is the most performant static embedding model. It reaches 92.11% of the performance of [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) with an average MTEB score of 51.66 while being orders of magnitude faster.

Note: the [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M), [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1), and [static-similarity-mrl-multilingual-v1](https://huggingface.co/minishlab/static-similarity-mrl-multilingual-v1) models are task-specific models. We've included them for completeness, but they should not be compared directly to the other models for tasks that they are not designed for.

The figure below shows the relationship between the number of sentences per second and the average MTEB score. The circle sizes correspond to the number of parameters in the models (larger = more parameters).
This plot shows that the potion and M2V models are much faster than the other models, while still being competitive in terms of performance with the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
NOTE: for fairness of comparison, we disabled multiprocessing for Model2Vec for this benchmark. All sentence-transformers models are run with the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) library's default settings for `encode`.

| ![Description](../assets/images/speed_vs_mteb_score_v2.png) |
| ![Description](../assets/images/speed_vs_mteb_score_v3.png) |
|:--:|
|*Figure: The average MTEB score plotted against sentences per second. The circle size indicates model size.*|


## Retrieval Results

A subset of models we created and compare against are specifically designed for retrieval tasks. The results are shown in the table below, including two general-purpose models for comparison and a transformer.

| Model | Retrieval Score |
|:-----------------------|------------------:|
| [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 41.95 |
| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | 36.35 |
| [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1) | 35.61 |
| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | 33.52 |
| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 31.71 |

As can be seen, [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) model is the most performant static retrieval model, reaching 86.65%% of the performance of [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) with a retrieval score of 36.35.

## Ablations

To better understand the factors contributing to the performance of Model2Vec, we conducted a comprehensive set of ablation studies, covering various aspects of the model's architecture and preprocessing methods. In these studies, we examined the impact of key elements such as PCA, Zipf weighting, and the use of Sentence Transformers versus regular transformer models. We also compared the performance of input embeddings versus output embeddings, since it would seem plausible that these should also work well. The results are shown in the table below.
Expand Down
184 changes: 184 additions & 0 deletions results/make_speed_vs_mteb_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Script to benchmark the speed of various text embedding models and generate a plot of the MTEB score vs samples per second."""

import argparse
import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Any

import numpy as np
import pandas as pd
from bpemb import BPEmb
from datasets import load_dataset
from plotnine import (
aes,
element_line,
geom_point,
geom_text,
ggplot,
guides,
labs,
scale_size,
scale_y_continuous,
theme,
theme_classic,
xlim,
ylim,
)
from sentence_transformers import SentenceTransformer

from model2vec import StaticModel

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


class BPEmbEmbedder:
def __init__(self, vs: int = 50_000, dim: int = 300) -> None:
"""Initialize the BPEmbEmbedder."""
self.bpemb_en = BPEmb(lang="en", vs=vs, dim=dim)

def mean_sentence_embedding(self, sentence: str) -> np.ndarray:
"""Encode a sentence to a mean embedding."""
encoded_ids = self.bpemb_en.encode_ids(sentence)
embeddings = self.bpemb_en.vectors[encoded_ids]
if embeddings.size == 0:
return np.zeros(self.bpemb_en.dim) # Return a zero vector if no tokens are found
return embeddings.mean(axis=0)

def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
"""Encode a list of sentences to embeddings."""
return np.array([self.mean_sentence_embedding(sentence.lower()) for sentence in sentences])


def make_plot(df: pd.DataFrame) -> ggplot:
"""Create a plot of the MTEB score vs samples per second."""
df["label_y"] = (
df["Average score"]
+ 0.5 # a constant "base" offset for all bubbles
+ 0.08 * np.sqrt(df["Params (Million)"])
)
plot = (
ggplot(df, aes(x="Samples per second", y="Average score"))
+ geom_point(aes(size="Params (Million)", color="Model"))
+ geom_text(aes(y="label_y", label="Model"), color="black", size=7)
+ scale_size(range=(2, 30))
+ theme_classic()
+ labs(title="Average MTEB Score vs Samples per Second")
+ ylim(df["Average score"].min(), df["Average score"].max() + 3)
+ scale_y_continuous(breaks=range(30, 70, 5))
+ theme(
panel_grid_major=element_line(color="lightgrey", size=0.5),
panel_grid_minor=element_line(color="lightgrey", size=0.25),
figure_size=(10, 6),
)
+ xlim(0, df["Samples per second"].max() + 100)
+ guides(None)
)
return plot


def benchmark_model(name: str, info: list[str], texts: list[str]) -> dict[str, float | str]:
"""Benchmark a single model."""
logger.info("Starting", name)
if info[1] == "BPEmb":
model = BPEmbEmbedder(vs=50_000, dim=300) # type: ignore
elif info[1] == "ST":
model = SentenceTransformer(info[0], device="cpu") # type: ignore
else:
model = StaticModel.from_pretrained(info[0]) # type: ignore

start = perf_counter()
if info[1] == "M2V":
# If the model is a model2vec model, disable multiprocessing for a fair comparison
model.encode(texts, use_multiprocessing=False)
else:
model.encode(texts)

total_time = perf_counter() - start
docs_per_second = len(texts) / total_time

logger.info(f"{name}: {docs_per_second} docs per second")
logger.info(f"Total time: {total_time}")

return {"docs_per_second": docs_per_second, "total_time": total_time}


def main(save_path: str, n_texts: int) -> None:
"""Benchmark text embedding models and generate a plot."""
# Define the models to benchmark
models: dict[str, list[str]] = {
"BPEmb-50k-300d": ["", "BPEmb"],
"all-MiniLM-L6-v2": ["sentence-transformers/all-MiniLM-L6-v2", "ST"],
"bge-base-en-v1.5": ["BAAI/bge-base-en-v1.5", "ST"],
"GloVe 6B 300d": ["sentence-transformers/average_word_embeddings_glove.6B.300d", "ST"],
"potion-base-8M": ["minishlab/potion-base-8M", "M2V"],
}

# Load the dataset
ds = load_dataset("wikimedia/wikipedia", data_files="20231101.en/train-00000-of-00041.parquet")["train"]
texts = ds["text"][:n_texts]

summarized_results = [
{"Model": "potion-base-2M", "Average score": 44.77, "Samples per second": None, "Params (Million)": 1.875},
{"Model": "GloVe 6B 300d", "Average score": 42.36, "Samples per second": None, "Params (Million)": 120.000},
{"Model": "potion-base-4M", "Average score": 48.23, "Samples per second": None, "Params (Million)": 3.750},
{"Model": "all-MiniLM-L6-v2", "Average score": 56.09, "Samples per second": None, "Params (Million)": 23.000},
{"Model": "potion-base-8M", "Average score": 50.03, "Samples per second": None, "Params (Million)": 7.500},
{"Model": "bge-base-en-v1.5", "Average score": 63.56, "Samples per second": None, "Params (Million)": 109.000},
{"Model": "M2V_base_output", "Average score": 45.34, "Samples per second": None, "Params (Million)": 7.500},
{"Model": "BPEmb-50k-300d", "Average score": 37.78, "Samples per second": None, "Params (Million)": 15.000},
{"Model": "potion-base-32M", "Average score": 51.66, "Samples per second": None, "Params (Million)": 32.300},
]

timings = {}

for name, info in models.items():
timing = benchmark_model(name, info, texts)
timings[name] = timing
# Update summarized results
for result in summarized_results:
if result["Model"] == name:
result["Samples per second"] = timing["docs_per_second"]

# Set potion-base-8M as the reference speed for the other potion models
potion_base_8m_speed = next(
result["Samples per second"] for result in summarized_results if result["Model"] == "potion-base-8M"
)
for model_name in ["M2V_base_output", "potion-base-2M", "potion-base-4M", "potion-base-32M"]:
for result in summarized_results:
if result["Model"] == model_name:
result["Samples per second"] = potion_base_8m_speed

# Ensure save_path is a directory
save_dir = Path(save_path)
save_dir.mkdir(parents=True, exist_ok=True)

# Save timings to JSON
json_path = save_dir / "speed_benchmark_results.json"
with open(json_path, "w") as file:
json.dump(timings, file, indent=4)

# Create and save the plot
df = pd.DataFrame(summarized_results)
plot = make_plot(df)
plot_path = save_dir / "speed_vs_mteb_plot.png"
plot.save(plot_path, width=12, height=10)

logger.info(f"Timings saved to {json_path}")
logger.info(f"Plot saved to {plot_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark text embedding models and generate a plot.")
parser.add_argument(
"--save-path", type=str, required=True, help="Directory to save the benchmark results and plot."
)
parser.add_argument(
"--n-texts", type=int, default=100_000, help="Number of texts to use from the dataset for benchmarking."
)
args = parser.parse_args()

main(save_path=args.save_path, n_texts=args.n_texts)