Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ Model2Vec is a technique to turn any sentence transformer into a really small fa

## Quickstart

Install the package with:
Install the package and all required extras with:
```bash
pip install model2vec[distill]
```

If you want a lightweight version of the package which only requires `numpy`, omit the `distill` extra.

```bash
pip install model2vec
```
Expand Down Expand Up @@ -118,7 +124,7 @@ For more documentation, please refer to the [Sentence Transformers documentation

## Main Features

Model2Vec is:
Model2Vec has the following features:

- **Small**: reduces the size of a Sentence Transformer model by a factor of 15, from 120M params, down to 7.5M (30 MB on disk, making it the smallest model on [MTEB](https://huggingface.co/spaces/mteb/leaderboard)!).
- **Static, but better**: smaller than GLoVe and BPEmb, but [much more performant](results/README.md), even with the same vocabulary.
Expand Down
3 changes: 1 addition & 2 deletions model2vec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from model2vec.distill import distill
from model2vec.model import StaticModel

__all__ = ["distill", "StaticModel"]
__all__ = ["StaticModel"]
7 changes: 7 additions & 0 deletions model2vec/distill/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from model2vec.utils import get_package_extras, importable

# Name of the pip extra that carries the optional distillation dependencies.
_REQUIRED_EXTRA = "distill"

# Check that every dependency declared under the `distill` extra can be
# imported BEFORE importing the distillation module below; presumably
# `importable` raises an informative error naming the missing extra — verify
# against model2vec.utils.
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

# Deliberately imported after the dependency check above has passed.
from model2vec.distill.distillation import distill, distill_from_model

__all__ = ["distill", "distill_from_model"]
202 changes: 202 additions & 0 deletions model2vec/hf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import json
import logging
from pathlib import Path
from typing import Any, Protocol, cast

import huggingface_hub
import huggingface_hub.errors
import numpy as np
import safetensors
from huggingface_hub import ModelCard, ModelCardData
from safetensors.numpy import save_file
from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


class SafeOpenProtocol(Protocol):
    """Structural type for the handle returned by ``safetensors.safe_open``."""

    def get_tensor(self, key: str) -> np.ndarray:
        """Return the stored tensor registered under ``key``."""
        ...

Check warning on line 22 in model2vec/hf_utils.py

View check run for this annotation

Codecov / codecov/patch

model2vec/hf_utils.py#L22

Added line #L22 was not covered by tests


def save_pretrained(
    folder_path: Path,
    embeddings: np.ndarray,
    tokenizer: Tokenizer,
    config: dict[str, Any],
    create_model_card: bool = True,
    **kwargs: Any,
) -> None:
    """
    Save a model to a folder.

    :param folder_path: The path to the folder.
    :param embeddings: The embeddings.
    :param tokenizer: The tokenizer.
    :param config: A metadata config.
    :param create_model_card: Whether to create a model card.
    :param **kwargs: Any additional arguments, forwarded to the model card creation.
    """
    folder_path.mkdir(exist_ok=True, parents=True)
    save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
    tokenizer.save(str(folder_path / "tokenizer.json"))
    # Use a context manager so the config file handle is closed deterministically
    # (the previous `json.dump(config, open(...))` leaked the handle).
    with open(folder_path / "config.json", "w") as config_file:
        json.dump(config, config_file)

    logger.info(f"Saved model to {folder_path}")

    # Optionally create the model card
    if create_model_card:
        _create_model_card(folder_path, **kwargs)


def _create_model_card(
    folder_path: Path,
    base_model_name: str = "unknown",
    license: str = "mit",
    language: list[str] | None = None,
    model_name: str | None = None,
    **kwargs: Any,
) -> None:
    """
    Build a model card from the package template and write it to ``folder_path``.

    :param folder_path: The path where the model card will be stored.
    :param base_model_name: The name of the base model.
    :param license: The license to use.
    :param language: The language of the model.
    :param model_name: The name of the model to use in the Model Card.
    :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.).
    """
    folder_path = Path(folder_path)
    # Fall back to the folder name when no explicit model name was given.
    model_name = model_name or folder_path.name
    template_path = Path(__file__).parent / "model_card_template.md"

    card_data = ModelCardData(
        model_name=model_name,
        base_model=base_model_name,
        license=license,
        language=language,
        tags=["embeddings", "static-embeddings"],
        library_name="model2vec",
        **kwargs,
    )
    card = ModelCard.from_template(card_data, template_path=template_path)
    card.save(folder_path / "README.md")


def load_pretrained(
    folder_or_repo_path: str | Path, token: str | None = None
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
    """
    Loads a pretrained model from a folder.

    :param folder_or_repo_path: The folder or repo path to load from.
        - If this is a local path, we will load from the local path.
        - If the local path is not found, we will attempt to load from the huggingface hub.
    :param token: The huggingface token to use.
    :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
    :return: The embeddings, tokenizer, config, and metadata.

    """
    folder_or_repo_path = Path(folder_or_repo_path)
    if folder_or_repo_path.exists():
        embeddings_path = folder_or_repo_path / "model.safetensors"
        if not embeddings_path.exists():
            # Backward compatibility: fall back to the pre-rename file name.
            old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
            if old_embeddings_path.exists():
                logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
                embeddings_path = old_embeddings_path
            else:
                raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")

        config_path = folder_or_repo_path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")

        tokenizer_path = folder_or_repo_path / "tokenizer.json"
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

        # README is optional, so this is a bit finicky.
        readme_path = folder_or_repo_path / "README.md"
        metadata = _get_metadata_from_readme(readme_path)

    else:
        logger.info("Folder does not exist locally, attempting to use huggingface hub.")
        try:
            embeddings_path = huggingface_hub.hf_hub_download(
                folder_or_repo_path.as_posix(), "model.safetensors", token=token
            )
        except huggingface_hub.utils.EntryNotFoundError as e:
            try:
                # Backward compatibility: try the pre-rename file name on the hub.
                embeddings_path = huggingface_hub.hf_hub_download(
                    folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
                )
            except huggingface_hub.utils.EntryNotFoundError:
                # Raise original exception.
                raise e

        try:
            readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
            metadata = _get_metadata_from_readme(Path(readme_path))
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No README found in the model folder. No model card loaded.")
            metadata = {}

        config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
        tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)

    opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
    # Use a context manager so the config file handle is closed deterministically
    # (the previous `json.load(open(...))` leaked the handle).
    with open(config_path) as config_file:
        config = json.load(config_file)

    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config, metadata


def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
    """Read the model-card metadata from a README file, returning {} when absent."""
    if readme_path.exists():
        card = ModelCard.load(readme_path)
        metadata: dict[str, Any] = card.data.to_dict()
        if not metadata:
            logger.info("File README.md exists, but was empty. No model card loaded.")
        return metadata
    logger.info(f"README file not found in {readme_path}. No model card loaded.")
    return {}


def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None:
    """
    Push a model folder to the huggingface hub, including model card.

    :param folder_path: The path to the folder.
    :param repo_id: The repo name.
    :param private: Whether the repo is private.
    :param token: The huggingface token.
    """
    # Create the repository first if it does not exist yet.
    if not huggingface_hub.repo_exists(repo_id=repo_id, token=token):
        huggingface_hub.create_repo(repo_id, token=token, private=private)

    # Push model card and all model files to the Hugging Face hub
    huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token)

    # Check if the model card exists, and push it if available
    model_card_path = folder_path / "README.md"
    if not model_card_path.exists():
        logger.warning(f"Model card README.md not found in {folder_path}. Skipping model card upload.")
    else:
        card = ModelCard.load(model_card_path)
        card.push_to_hub(repo_id=repo_id, token=token)
        logger.info(f"Pushed model card to {repo_id}")

    logger.info(f"Pushed model to {repo_id}")

Check warning on line 202 in model2vec/hf_utils.py

View check run for this annotation

Codecov / codecov/patch

model2vec/hf_utils.py#L202

Added line #L202 was not covered by tests
Loading