Commit a7ef098

docs: Added new model results (#167)

1 parent 97cde9b - commit a7ef098

4 files changed: +218 -10 lines changed

README.md

Lines changed: 4 additions & 1 deletion
@@ -39,11 +39,12 @@
   <img src="assets/images/model2vec_model_diagram_transparant_light.png#gh-light-mode-only" width="90%">
 </div>

-Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by 15x and making the models up to 500x faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.
+Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by 15x and making the models up to 500x faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-32M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.

 ## Updates & Announcements

+- **30/01/2025**: We released two new models: [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) and [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M). [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is our most performant model to date, using a larger vocabulary and higher dimensions. [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) is a finetune of [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) that is optimized for retrieval tasks, and is the best performing static retrieval model currently available.
 - **30/10/2024**: We released three new models: [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M), [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M), and [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M). These models are trained using [Tokenlearn](https://github.com/MinishLab/tokenlearn). Find out more in our [blog post](https://minishlab.github.io/tokenlearn_blogpost/). NOTE: for users of any of our old English M2V models, we recommend switching to these new models as they [perform better on all tasks](https://github.com/MinishLab/model2vec/tree/main/results).

 ## Table of Contents
@@ -491,9 +492,11 @@ We provide a number of models that can be used out of the box. These models are

 | Model | Language | Vocab | Sentence Transformer | Tokenizer Type | Params | Tokenlearn |
 |-----------------------------------------------------------------------|-------------|------------------|-----------------------------------------------------------------|----------------|---------|-------------------|
+| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | English | Output + Frequent C4 tokens | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 32.3M | <div align="center">✅</div> |
 | [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 7.5M | <div align="center">✅</div> |
 | [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 3.7M | <div align="center">✅</div> |
 | [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | English | Output | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 1.8M | <div align="center">✅</div> |
+| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | English | Output + Frequent C4 tokens | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | Subword | 32.3M | <div align="center">✅</div> |
 | [M2V_multilingual_output](https://huggingface.co/minishlab/M2V_multilingual_output) | Multilingual | Output | [LaBSE](https://huggingface.co/sentence-transformers/LaBSE) | Subword | 471M | <div align="center">❌</div> |
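For readers picking up the newly announced model, a minimal usage sketch: it relies only on the `StaticModel.from_pretrained` and `encode` calls that the benchmark script added in this commit also uses, and the example sentences are illustrative.

```python
from model2vec import StaticModel

# Load the newly announced flagship static model from the Hugging Face Hub.
model = StaticModel.from_pretrained("minishlab/potion-base-32M")

# Encode a few illustrative sentences into static embeddings.
embeddings = model.encode(["Model2Vec distills sentence transformers.", "Static models are fast."])
print(embeddings.shape)  # (2, embedding_dim)
```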

assets/images/speed_vs_mteb_score_v3.png

154 KB (binary image added)

results/README.md

Lines changed: 30 additions & 9 deletions
@@ -13,14 +13,16 @@ Note: The `potion` and `M2V` models are our static models.
 | Model | Avg (All) | Avg (MTEB) | Class | Clust | PairClass | Rank | Ret | STS | Sum | Pearl | WordSim |
 |:-----------------------|------------:|-------------:|--------:|--------:|------------:|-------:|-------:|-------:|-------:|--------:|----------:|
 | [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 56.08 | 56.09 | 62.62 | 41.94 | 82.37 | 58.04 | 41.95 | 78.90 | 30.81 | 60.83 | 49.91 |
-| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 50.54 | 50.03 | 64.44 | 32.93 | 76.62 | 49.73 | 31.71 | 73.24 | 29.28 | 53.54 | 50.75 |
-| [M2V_base_glove_subword](https://huggingface.co/minishlab/M2V_base_glove_subword) | 49.06 | 46.69 | 61.27 | 30.03 | 74.71 | 49.15 | 27.16 | 69.09 | 30.08 | 56.82 | 57.99 |
-| [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | 48.87 | 48.23 | 62.19 | 31.47 | 75.37 | 48.75 | 29.11 | 72.19 | 28.89 | 52.55 | 49.21 |
-| [M2V_base_glove](https://huggingface.co/minishlab/M2V_base_glove) | 48.58 | 47.6 | 61.35 | 30.52 | 75.34 | 48.5 | 29.26 | 70.31 | 31.5 | 50.28 | 54.29 |
-| [M2V_base_output](https://huggingface.co/minishlab/M2V_base_output) | 46.79 | 45.34 | 61.25 | 25.58 | 74.9 | 47.63 | 26.14 | 68.58 | 29.2 | 54.02 | 49.18 |
-| [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | 45.52 | 44.77 | 58.45 | 27.5 | 73.72 | 46.82 | 24.13 | 70.14 | 31.51 | 50.82 | 44.72 |
-| [GloVe_300d](https://huggingface.co/sentence-transformers/average_word_embeddings_glove.6B.300d) | 42.84 | 42.36 | 57.31 | 27.66 | 72.48 | 43.3 | 22.78 | 61.9 | 28.81 | 45.65 | 43.05 |
-| [BPEmb_50k_300d](https://github.com/bheinzerling/bpemb) | 39.34 | 37.78 | 55.76 | 23.35 | 57.86 | 43.21 | 17.5 | 55.1 | 29.74 | 47.56 | 41.28 |
+| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | 52.46 | 51.66 | 65.97 | 35.29 | 78.17 | 50.92 | 33.52 | 74.22 | 29.78 | 55.37 | 55.15 |
+| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 50.54 | 50.03 | 64.44 | 32.93 | 76.62 | 49.73 | 31.71 | 73.24 | 29.28 | 53.54 | 50.75 |
+| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | 49.73 | 49.76 | 59.56 | 30.55 | 76.38 | 50.05 | 36.35 | 73.22 | 28.85 | 49.31 | 50.02 |
+| [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | 48.87 | 48.23 | 62.19 | 31.47 | 75.37 | 48.75 | 29.11 | 72.19 | 28.89 | 52.55 | 49.21 |
+| [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1) | 48.18 | 48.36 | 57.39 | 28.32 | 75.63 | 49.16 | 35.61 | 72.18 | 28.64 | 49.68 | 44.76 |
+| [static-similarity-mrl-multilingual-v1](https://huggingface.co/minishlab/static-similarity-mrl-multilingual-v1) | 48.15 | 47.15 | 59.96 | 24.40 | 79.02 | 48.25 | 29.54 | 74.88 | 30.28 | 51.66 | 51.66 |
+| [M2V_base_output](https://huggingface.co/minishlab/M2V_base_output) | 46.79 | 45.34 | 61.25 | 25.58 | 74.9 | 47.63 | 26.14 | 68.58 | 29.2 | 54.02 | 49.18 |
+| [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | 45.52 | 44.77 | 58.45 | 27.5 | 73.72 | 46.82 | 24.13 | 70.14 | 31.51 | 50.82 | 44.72 |
+| [GloVe_300d](https://huggingface.co/sentence-transformers/average_word_embeddings_glove.6B.300d) | 42.84 | 42.36 | 57.31 | 27.66 | 72.48 | 43.3 | 22.78 | 61.9 | 28.81 | 45.65 | 43.05 |
+| [BPEmb_50k_300d](https://github.com/bheinzerling/bpemb) | 39.34 | 37.78 | 55.76 | 23.35 | 57.86 | 43.21 | 17.5 | 55.1 | 29.74 | 47.56 | 41.28 |

 <details>
@@ -36,14 +38,33 @@ For readability, the MTEB task names are abbreviated as follows:
 - Sum: Summarization
 </details>

+The results show that [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is the most performant static embedding model. It reaches 92.11% of the performance of [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), with an average MTEB score of 51.66, while being orders of magnitude faster.
+
+Note: the [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M), [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1), and [static-similarity-mrl-multilingual-v1](https://huggingface.co/minishlab/static-similarity-mrl-multilingual-v1) models are task-specific. We've included them for completeness, but they should not be compared directly to the other models on tasks they were not designed for.
+
 The figure below shows the relationship between the number of sentences per second and the average MTEB score. The circle sizes correspond to the number of parameters in the models (larger = more parameters).
 This plot shows that the potion and M2V models are much faster than the other models, while still being competitive in terms of performance with the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
+NOTE: for fairness of comparison, we disabled multiprocessing for Model2Vec in this benchmark. All sentence-transformers models are run with the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) library's default settings for `encode`.

-| ![Description](../assets/images/speed_vs_mteb_score_v2.png) |
+| ![Description](../assets/images/speed_vs_mteb_score_v3.png) |
 |:--:|
 |*Figure: The average MTEB score plotted against sentences per second. The circle size indicates model size.*|

+## Retrieval Results
+
+A subset of the models we created, as well as some of the models we compare against, are specifically designed for retrieval tasks. The results are shown in the table below, alongside two general-purpose static models and a transformer for comparison.
+
+| Model | Retrieval Score |
+|:-----------------------|------------------:|
+| [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 41.95 |
+| [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | 36.35 |
+| [static-retrieval-mrl-en-v1](https://huggingface.co/minishlab/static-retrieval-mrl-en-v1) | 35.61 |
+| [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | 33.52 |
+| [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | 31.71 |
+
+As can be seen, the [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) model is the most performant static retrieval model, reaching 86.65% of the performance of [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) with a retrieval score of 36.35.
+
 ## Ablations

 To better understand the factors contributing to the performance of Model2Vec, we conducted a comprehensive set of ablation studies, covering various aspects of the model's architecture and preprocessing methods. In these studies, we examined the impact of key elements such as PCA, Zipf weighting, and the use of Sentence Transformers versus regular transformer models. We also compared the performance of input embeddings versus output embeddings, since it would seem plausible that these should also work well. The results are shown in the table below.
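To make the retrieval comparison above concrete, here is a minimal retrieval sketch built on the same `StaticModel.encode` API used elsewhere in this commit; the toy corpus, query, and cosine-similarity ranking are illustrative assumptions, not part of the benchmark.

```python
import numpy as np

from model2vec import StaticModel

# Hypothetical toy corpus and query, for illustration only.
corpus = [
    "Model2Vec distills sentence transformers into static models.",
    "Static embeddings are fast to compute on CPU.",
    "The weather is nice today.",
]
query = "fast static embedding models"

model = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
corpus_embeddings = model.encode(corpus)
query_embedding = model.encode([query])[0]

# Rank the corpus by cosine similarity to the query.
scores = corpus_embeddings @ query_embedding / (
    np.linalg.norm(corpus_embeddings, axis=1) * np.linalg.norm(query_embedding)
)
for idx in np.argsort(-scores):
    print(f"{scores[idx]:.3f}  {corpus[idx]}")
```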

results/make_speed_vs_mteb_plot.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@

"""Script to benchmark the speed of various text embedding models and generate a plot of the MTEB score vs samples per second."""

import argparse
import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Any

import numpy as np
import pandas as pd
from bpemb import BPEmb
from datasets import load_dataset
from plotnine import (
    aes,
    element_line,
    geom_point,
    geom_text,
    ggplot,
    guides,
    labs,
    scale_size,
    scale_y_continuous,
    theme,
    theme_classic,
    xlim,
    ylim,
)
from sentence_transformers import SentenceTransformer

from model2vec import StaticModel

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


class BPEmbEmbedder:
    def __init__(self, vs: int = 50_000, dim: int = 300) -> None:
        """Initialize the BPEmbEmbedder."""
        self.bpemb_en = BPEmb(lang="en", vs=vs, dim=dim)

    def mean_sentence_embedding(self, sentence: str) -> np.ndarray:
        """Encode a sentence to a mean embedding."""
        encoded_ids = self.bpemb_en.encode_ids(sentence)
        embeddings = self.bpemb_en.vectors[encoded_ids]
        if embeddings.size == 0:
            return np.zeros(self.bpemb_en.dim)  # Return a zero vector if no tokens are found
        return embeddings.mean(axis=0)

    def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
        """Encode a list of sentences to embeddings."""
        return np.array([self.mean_sentence_embedding(sentence.lower()) for sentence in sentences])


def make_plot(df: pd.DataFrame) -> ggplot:
    """Create a plot of the MTEB score vs samples per second."""
    df["label_y"] = (
        df["Average score"]
        + 0.5  # a constant "base" offset for all bubbles
        + 0.08 * np.sqrt(df["Params (Million)"])  # push labels further up for larger bubbles
    )
    plot = (
        ggplot(df, aes(x="Samples per second", y="Average score"))
        + geom_point(aes(size="Params (Million)", color="Model"))
        + geom_text(aes(y="label_y", label="Model"), color="black", size=7)
        + scale_size(range=(2, 30))
        + theme_classic()
        + labs(title="Average MTEB Score vs Samples per Second")
        + ylim(df["Average score"].min(), df["Average score"].max() + 3)
        + scale_y_continuous(breaks=range(30, 70, 5))
        + theme(
            panel_grid_major=element_line(color="lightgrey", size=0.5),
            panel_grid_minor=element_line(color="lightgrey", size=0.25),
            figure_size=(10, 6),
        )
        + xlim(0, df["Samples per second"].max() + 100)
        + guides(color=None)  # Models are labeled directly via geom_text, so the color legend is redundant
    )
    return plot


def benchmark_model(name: str, info: list[str], texts: list[str]) -> dict[str, float]:
    """Benchmark a single model."""
    logger.info("Starting %s", name)
    if info[1] == "BPEmb":
        model = BPEmbEmbedder(vs=50_000, dim=300)  # type: ignore
    elif info[1] == "ST":
        model = SentenceTransformer(info[0], device="cpu")  # type: ignore
    else:
        model = StaticModel.from_pretrained(info[0])  # type: ignore

    start = perf_counter()
    if info[1] == "M2V":
        # If the model is a model2vec model, disable multiprocessing for a fair comparison
        model.encode(texts, use_multiprocessing=False)
    else:
        model.encode(texts)

    total_time = perf_counter() - start
    docs_per_second = len(texts) / total_time

    logger.info(f"{name}: {docs_per_second} docs per second")
    logger.info(f"Total time: {total_time}")

    return {"docs_per_second": docs_per_second, "total_time": total_time}


def main(save_path: str, n_texts: int) -> None:
    """Benchmark text embedding models and generate a plot."""
    # Define the models to benchmark: name -> [model path, model type]
    models: dict[str, list[str]] = {
        "BPEmb-50k-300d": ["", "BPEmb"],
        "all-MiniLM-L6-v2": ["sentence-transformers/all-MiniLM-L6-v2", "ST"],
        "bge-base-en-v1.5": ["BAAI/bge-base-en-v1.5", "ST"],
        "GloVe 6B 300d": ["sentence-transformers/average_word_embeddings_glove.6B.300d", "ST"],
        "potion-base-8M": ["minishlab/potion-base-8M", "M2V"],
    }

    # Load the dataset
    ds = load_dataset("wikimedia/wikipedia", data_files="20231101.en/train-00000-of-00041.parquet")["train"]
    texts = ds["text"][:n_texts]

    # Pre-computed MTEB scores and parameter counts; speeds are filled in below
    summarized_results = [
        {"Model": "potion-base-2M", "Average score": 44.77, "Samples per second": None, "Params (Million)": 1.875},
        {"Model": "GloVe 6B 300d", "Average score": 42.36, "Samples per second": None, "Params (Million)": 120.000},
        {"Model": "potion-base-4M", "Average score": 48.23, "Samples per second": None, "Params (Million)": 3.750},
        {"Model": "all-MiniLM-L6-v2", "Average score": 56.09, "Samples per second": None, "Params (Million)": 23.000},
        {"Model": "potion-base-8M", "Average score": 50.03, "Samples per second": None, "Params (Million)": 7.500},
        {"Model": "bge-base-en-v1.5", "Average score": 63.56, "Samples per second": None, "Params (Million)": 109.000},
        {"Model": "M2V_base_output", "Average score": 45.34, "Samples per second": None, "Params (Million)": 7.500},
        {"Model": "BPEmb-50k-300d", "Average score": 37.78, "Samples per second": None, "Params (Million)": 15.000},
        {"Model": "potion-base-32M", "Average score": 51.66, "Samples per second": None, "Params (Million)": 32.300},
    ]

    timings = {}

    for name, info in models.items():
        timing = benchmark_model(name, info, texts)
        timings[name] = timing
        # Update summarized results
        for result in summarized_results:
            if result["Model"] == name:
                result["Samples per second"] = timing["docs_per_second"]

    # Use potion-base-8M's measured speed as the reference speed for the other Model2Vec models
    potion_base_8m_speed = next(
        result["Samples per second"] for result in summarized_results if result["Model"] == "potion-base-8M"
    )
    for model_name in ["M2V_base_output", "potion-base-2M", "potion-base-4M", "potion-base-32M"]:
        for result in summarized_results:
            if result["Model"] == model_name:
                result["Samples per second"] = potion_base_8m_speed

    # Ensure save_path is a directory
    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save timings to JSON
    json_path = save_dir / "speed_benchmark_results.json"
    with open(json_path, "w") as file:
        json.dump(timings, file, indent=4)

    # Create and save the plot
    df = pd.DataFrame(summarized_results)
    plot = make_plot(df)
    plot_path = save_dir / "speed_vs_mteb_plot.png"
    plot.save(plot_path, width=12, height=10)

    logger.info(f"Timings saved to {json_path}")
    logger.info(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark text embedding models and generate a plot.")
    parser.add_argument(
        "--save-path", type=str, required=True, help="Directory to save the benchmark results and plot."
    )
    parser.add_argument(
        "--n-texts", type=int, default=100_000, help="Number of texts to use from the dataset for benchmarking."
    )
    args = parser.parse_args()

    main(save_path=args.save_path, n_texts=args.n_texts)
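Assuming the dependencies imported at the top of the script are installed, it can be invoked as, for example, `python results/make_speed_vs_mteb_plot.py --save-path results/benchmark --n-texts 100000` (the output directory name here is an arbitrary choice); it writes `speed_benchmark_results.json` and `speed_vs_mteb_plot.png` into that directory.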
