"""Script to benchmark the speed of various text embedding models and generate a plot of the MTEB score vs samples per second."""

import argparse
import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Any

import numpy as np
import pandas as pd
from bpemb import BPEmb
from datasets import load_dataset
from plotnine import (
    aes,
    element_line,
    geom_point,
    geom_text,
    ggplot,
    guides,
    labs,
    scale_size,
    scale_y_continuous,
    theme,
    theme_classic,
    xlim,
)
from sentence_transformers import SentenceTransformer

from model2vec import StaticModel

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


class BPEmbEmbedder:
    """Wrapper around BPEmb that mean-pools subword embeddings and exposes an encode() method."""

    def __init__(self, vs: int = 50_000, dim: int = 300) -> None:
        """Initialize the BPEmbEmbedder."""
        self.bpemb_en = BPEmb(lang="en", vs=vs, dim=dim)

    def mean_sentence_embedding(self, sentence: str) -> np.ndarray:
        """Encode a sentence to a mean embedding."""
        encoded_ids = self.bpemb_en.encode_ids(sentence)
        embeddings = self.bpemb_en.vectors[encoded_ids]
        if embeddings.size == 0:
            return np.zeros(self.bpemb_en.dim)  # Return a zero vector if no tokens are found
        return embeddings.mean(axis=0)

    def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
        """Encode a list of sentences to embeddings."""
        return np.array([self.mean_sentence_embedding(sentence.lower()) for sentence in sentences])

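# Minimal illustrative usage of the wrapper above (not executed by this script;
# BPEmb downloads the pretrained English subword vectors on first use):
#   embedder = BPEmbEmbedder(vs=50_000, dim=300)
#   vectors = embedder.encode(["An example sentence."])  # expected shape: (1, 300)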

def make_plot(df: pd.DataFrame) -> ggplot:
    """Create a plot of the MTEB score vs samples per second."""
    df["label_y"] = (
        df["Average score"]
        + 0.5  # a constant "base" offset for all bubbles
        + 0.08 * np.sqrt(df["Params (Million)"])
    )
    plot = (
        ggplot(df, aes(x="Samples per second", y="Average score"))
        + geom_point(aes(size="Params (Million)", color="Model"))
        + geom_text(aes(y="label_y", label="Model"), color="black", size=7)
        + scale_size(range=(2, 30))
        + theme_classic()
        + labs(title="Average MTEB Score vs Samples per Second")
        # Set the limits and breaks in a single y scale so plotnine does not replace one y scale with another
        + scale_y_continuous(
            limits=(df["Average score"].min(), df["Average score"].max() + 3),
            breaks=range(30, 70, 5),
        )
        + theme(
            panel_grid_major=element_line(color="lightgrey", size=0.5),
            panel_grid_minor=element_line(color="lightgrey", size=0.25),
            figure_size=(10, 6),
        )
        + xlim(0, df["Samples per second"].max() + 100)
        # The points are already labeled with the model names, so hide the size and color legends
        + guides(size=None, color=None)
    )
    return plot

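# make_plot expects the columns "Model", "Average score", "Samples per second" and
# "Params (Million)", as assembled in main() below. Illustrative call with hypothetical data:
#   make_plot(pd.DataFrame([{"Model": "example-model", "Average score": 50.0,
#                            "Samples per second": 1000.0, "Params (Million)": 8.0}]))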

def benchmark_model(name: str, info: list[str], texts: list[str]) -> dict[str, float]:
    """Benchmark a single model."""
    logger.info(f"Starting {name}")
    if info[1] == "BPEmb":
        model = BPEmbEmbedder(vs=50_000, dim=300)  # type: ignore
    elif info[1] == "ST":
        model = SentenceTransformer(info[0], device="cpu")  # type: ignore
    else:
        model = StaticModel.from_pretrained(info[0])  # type: ignore

    start = perf_counter()
    if info[1] == "M2V":
        # If the model is a model2vec model, disable multiprocessing for a fair comparison
        model.encode(texts, use_multiprocessing=False)
    else:
        model.encode(texts)

    total_time = perf_counter() - start
    docs_per_second = len(texts) / total_time

    logger.info(f"{name}: {docs_per_second} docs per second")
    logger.info(f"Total time: {total_time}")

    return {"docs_per_second": docs_per_second, "total_time": total_time}

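# Illustrative call using one of the entries from the models dict in main() below;
# the returned dict has the keys constructed above:
#   benchmark_model("potion-base-8M", ["minishlab/potion-base-8M", "M2V"], texts)
#   # -> {"docs_per_second": ..., "total_time": ...}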

def main(save_path: str, n_texts: int) -> None:
    """Benchmark text embedding models and generate a plot."""
    # Define the models to benchmark: display name -> [model identifier, backend tag used by benchmark_model]
    models: dict[str, list[str]] = {
        "BPEmb-50k-300d": ["", "BPEmb"],
        "all-MiniLM-L6-v2": ["sentence-transformers/all-MiniLM-L6-v2", "ST"],
        "bge-base-en-v1.5": ["BAAI/bge-base-en-v1.5", "ST"],
        "GloVe 6B 300d": ["sentence-transformers/average_word_embeddings_glove.6B.300d", "ST"],
        "potion-base-8M": ["minishlab/potion-base-8M", "M2V"],
    }

    # Load the first shard of the English Wikipedia dataset and take the first n_texts texts
    ds = load_dataset("wikimedia/wikipedia", data_files="20231101.en/train-00000-of-00041.parquet")["train"]
    texts = ds["text"][:n_texts]

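    # The average MTEB scores below are precomputed; "Samples per second" starts as None
    # and is filled in (or reused from potion-base-8M) using the timings measured in this run.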
    summarized_results = [
        {"Model": "potion-base-2M", "Average score": 44.77, "Samples per second": None, "Params (Million)": 1.875},
        {"Model": "GloVe 6B 300d", "Average score": 42.36, "Samples per second": None, "Params (Million)": 120.000},
        {"Model": "potion-base-4M", "Average score": 48.23, "Samples per second": None, "Params (Million)": 3.750},
        {"Model": "all-MiniLM-L6-v2", "Average score": 56.09, "Samples per second": None, "Params (Million)": 23.000},
        {"Model": "potion-base-8M", "Average score": 50.03, "Samples per second": None, "Params (Million)": 7.500},
        {"Model": "bge-base-en-v1.5", "Average score": 63.56, "Samples per second": None, "Params (Million)": 109.000},
        {"Model": "M2V_base_output", "Average score": 45.34, "Samples per second": None, "Params (Million)": 7.500},
        {"Model": "BPEmb-50k-300d", "Average score": 37.78, "Samples per second": None, "Params (Million)": 15.000},
        {"Model": "potion-base-32M", "Average score": 51.66, "Samples per second": None, "Params (Million)": 32.300},
    ]

    timings = {}

    for name, info in models.items():
        timing = benchmark_model(name, info, texts)
        timings[name] = timing
        # Update summarized results
        for result in summarized_results:
            if result["Model"] == name:
                result["Samples per second"] = timing["docs_per_second"]

    # Use potion-base-8M's measured speed as the reference for the other model2vec models, which are not benchmarked directly
    potion_base_8m_speed = next(
        result["Samples per second"] for result in summarized_results if result["Model"] == "potion-base-8M"
    )
    for model_name in ["M2V_base_output", "potion-base-2M", "potion-base-4M", "potion-base-32M"]:
        for result in summarized_results:
            if result["Model"] == model_name:
                result["Samples per second"] = potion_base_8m_speed

    # Ensure save_path is a directory
    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save timings to JSON
    json_path = save_dir / "speed_benchmark_results.json"
    with open(json_path, "w") as file:
        json.dump(timings, file, indent=4)

    # Create and save the plot
    df = pd.DataFrame(summarized_results)
    plot = make_plot(df)
    plot_path = save_dir / "speed_vs_mteb_plot.png"
    plot.save(plot_path, width=12, height=10)

    logger.info(f"Timings saved to {json_path}")
    logger.info(f"Plot saved to {plot_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark text embedding models and generate a plot.")
    parser.add_argument(
        "--save-path", type=str, required=True, help="Directory to save the benchmark results and plot."
    )
    parser.add_argument(
        "--n-texts", type=int, default=100_000, help="Number of texts to use from the dataset for benchmarking."
    )
    args = parser.parse_args()

    main(save_path=args.save_path, n_texts=args.n_texts)
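
# Example invocation (the file name is an assumption; adjust it to wherever this script lives):
#   python speed_benchmark.py --save-path results/speed_benchmark --n-texts 100000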