Add fittable #140
Merged
Commits (64)
4078a3b  stephantul  Fix tokenizer issue
09f888d  stephantul  fix issue with warning
2167a4e  stephantul  regenerate lock file
c95dca5  stephantul  fix lock file
b5d8bb7  stephantul  Try to not select 2.5.1
3e68669  stephantul  fix: issue with dividers in utils
1ae4d61  stephantul  Try to not select 2.5.0
1349b0c  stephantul  fix: do not up version
4b83d59  stephantul  Attempt special fix
9515b83  stephantul  merge
dfd865b  stephantul  feat: add training
c4ba272  stephantul  merge with old
4713bfa  stephantul  fix: no grad
e8058bb  stephantul  use numpy
a59127e  stephantul  Add train_test_split
310fbb5  stephantul  fix: issue with fit not resetting
b1899d1  stephantul  feat: add lightning
e27f9dc  stephantul  merge
8df3aaf  stephantul  Fix bugs
839d88a  stephantul  fix: reviewer comments
8457357  stephantul  fix train issue
a750709  stephantul  fix issue with trainer
e83c54e  stephantul  fix: truncate during training
803565d  stephantul  feat: tokenize maximum length truncation
9052806  stephantul  fixes
2f9fbf4  stephantul  typo
f1e08c3  stephantul  Add progressbar
bb54a76  stephantul  small code changes, add docs
69ee4ee  stephantul  fix training comments
9962be7  stephantul  Merge branch 'main' into add-fittable
ffec235  stephantul  Add pipeline saving
0af84fc  stephantul  fix bug
c829745  stephantul  fix issue with normalize test
9ce65a1  stephantul  change default batch size
e1169fb  stephantul  feat: add sklearn skops pipeline
f096824  stephantul  Device handling and automatic batch size
ff3ebdf  stephantul  Add docstrings, defaults
b4e966a  stephantul  docs
8f65bfd  stephantul  fix: rename
8cdb668  stephantul  fix: rename
e96a72a  stephantul  fix installation
3e76083  stephantul  rename
9f1cb5a  stephantul  Add training tutorial
e2d92b9  stephantul  Add tutorial link
657cef0  stephantul  Merge branch 'main' into add-fittable
773009f  stephantul  test: add tests
7015341  stephantul  fix tests
8ab8456  stephantul  tests: fix tests
e21e61f  stephantul  Address comments
ff75af9  stephantul  Add inference reqs to train reqs
87de7c4  stephantul  fix normalize
1fb33f1  stephantul  update lock file
59f0076  stephantul  Merge branch 'main' into add-fittable
009342b  stephantul  Merge branch 'main' into add-fittable
261a9b4  stephantul  fix: move modelcards
e1d53ac  stephantul  fix: batch size
6b5f991  stephantul  update lock file
759b96c  stephantul  Update model2vec/inference/README.md
7caf9bc  stephantul  Update model2vec/inference/README.md
c7b68b6  stephantul  Update model2vec/inference/README.md
be7baa1  stephantul  Update model2vec/train/classifier.py
cc74618  stephantul  fix: encode args
a4d8d6c  stephantul  fix: trust_remote_code
a0d56d5  stephantul  fix notebook
New file (6 lines; presumably model2vec/train/__init__.py, guarding the optional training dependencies):
@@ -0,0 +1,6 @@
from model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "train"

for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)
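
For context, a minimal sketch of how this guard plays out at import time. The failure behavior of `importable` is an assumption (the diff only shows its call site), and the snippet is not part of the PR:

```python
# Hypothetical sketch, not part of the diff: importing the training subpackage
# runs the extras check above. If a dependency from the "train" extra is missing,
# `importable` is assumed to raise an informative ImportError.
try:
    import model2vec.train
except ImportError as err:
    print(f"Training extras are not installed: {err}")
```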
New file: model2vec/train/base.py
@@ -0,0 +1,168 @@
from __future__ import annotations

from typing import Any, TypeVar

import torch
from tokenizers import Encoding, Tokenizer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from model2vec import StaticModel


class FinetunableStaticModel(nn.Module):
    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int, pad_id: int = 0) -> None:
        """
        Initialize a trainable StaticModel from a StaticModel.

        :param vectors: The embeddings of the static model.
        :param tokenizer: The tokenizer.
        :param out_dim: The output dimension of the head.
        :param pad_id: The padding id. This is set to 0 in almost all model2vec models.
        """
        super().__init__()
        self.pad_id = pad_id
        self.out_dim = out_dim
        self.embed_dim = vectors.shape[1]
        self.vectors = vectors

        self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
        self.head = self.construct_head()

        weights = torch.zeros(len(vectors))
        weights[pad_id] = -10_000
        self.w = nn.Parameter(weights)
        self.tokenizer = tokenizer

    def construct_head(self) -> nn.Module:
        """Construct the head. Should be overridden by subclasses."""
        return nn.Linear(self.embed_dim, self.out_dim)

    @classmethod
    def from_pretrained(
        cls: type[ModelType], out_dim: int, model_name: str = "minishlab/potion-base-8m", **kwargs: Any
    ) -> ModelType:
        """Load the model from a pretrained model2vec model."""
        model = StaticModel.from_pretrained(model_name)
        return cls.from_static_model(model, out_dim, **kwargs)

    @classmethod
    def from_static_model(cls: type[ModelType], model: StaticModel, out_dim: int, **kwargs: Any) -> ModelType:
        """Load the model from a static model."""
        embeddings_converted = torch.from_numpy(model.embedding)
        return cls(
            vectors=embeddings_converted,
            pad_id=model.tokenizer.token_to_id("[PAD]"),
            out_dim=out_dim,
            tokenizer=model.tokenizer,
            **kwargs,
        )

    def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        A forward pass and mean pooling.

        This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
        to pass through.

        :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
        :return: The mean over the input ids, weighted by token weights.
        """
        w = self.w[input_ids]
        w = torch.sigmoid(w)
        # Zero out the weights of padding tokens.
        zeros = (input_ids != self.pad_id).float()
        w = w * zeros
        # Add a small epsilon to avoid division by zero.
        length = zeros.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weighted sum over the sequence dimension, then divide by the number of
        # non-padding tokens to simulate an actual (weighted) mean.
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        embedded = embedded / length[:, None]

        return nn.functional.normalize(embedded)

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the mean pooling, followed by the classifier head."""
        encoded = self._encode(input_ids)
        return self.head(encoded), encoded

    def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
        """
        Tokenize a bunch of strings into a single padded 2D tensor.

        Note that this is not used during training.

        :param texts: The texts to tokenize.
        :param max_length: The maximum sequence length. If this is None, sequences are not truncated. Defaults to 512.
        :return: A 2D padded tensor.
        """
        encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
        encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
        return pad_sequence(encoded_ids, batch_first=True)

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device

    def to_static_model(self, config: dict[str, Any] | None = None) -> StaticModel:
        """
        Convert the model to a static model.

        This is useful if you want to discard your head, and consolidate the information learned by
        the model to use it in a downstream task.

        :param config: The config used in the StaticModel. If this is set to None, it will have no config.
        :return: A static model.
        """
        # Perform the forward pass on the selected device.
        with torch.no_grad():
            all_indices = torch.arange(len(self.embeddings.weight))[:, None].to(self.device)
            vectors = self._encode(all_indices).cpu().numpy()

        new_model = StaticModel(vectors=vectors, tokenizer=self.tokenizer, config=config)

        return new_model


class TextDataset(Dataset):
    def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
        """
        A dataset of texts.

        :param tokenized_texts: The tokenized texts. Each text is a list of token ids.
        :param targets: The targets.
        :raises ValueError: If the number of labels does not match the number of texts.
        """
        if len(targets) != len(tokenized_texts):
            raise ValueError("Number of labels does not match number of texts.")
        self.tokenized_texts = tokenized_texts
        self.targets = targets

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.tokenized_texts)

    def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
        """Get a single item."""
        return self.tokenized_texts[index], self.targets[index]

    @staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate a batch of (token ids, target) pairs into padded tensors."""
        texts, targets = zip(*batch)

        tensors = [torch.LongTensor(x) for x in texts]
        padded = pad_sequence(tensors, batch_first=True, padding_value=0)

        return padded, torch.stack(targets)

    def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
        """Convert the dataset to a DataLoader."""
        return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)


ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
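
A minimal usage sketch of `FinetunableStaticModel`, assuming the default `minishlab/potion-base-8m` checkpoint can be downloaded; the example texts and the 2-dimensional head are illustrative and not taken from the PR:

```python
from model2vec.train.base import FinetunableStaticModel

# Build a trainable model with a 2-dimensional linear head on top of the pooled embedding.
model = FinetunableStaticModel.from_pretrained(out_dim=2)

# Tokenize a small batch and run a forward pass; gradients flow through the weighted pooling.
input_ids = model.tokenize(["a good movie", "a terrible movie"])
logits, pooled = model(input_ids)
print(logits.shape, pooled.shape)  # (2, out_dim) and (2, embed_dim)

# Bake the learned token weights back into a plain StaticModel for fast inference.
static_model = model.to_static_model()
```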
New file: model2vec/train/classifier.py
@@ -0,0 +1,204 @@
from __future__ import annotations

import logging
from collections import Counter
from typing import Any

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback, EarlyStopping
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from sklearn.model_selection import train_test_split
from tokenizers import Tokenizer
from torch import nn

from model2vec.train.base import FinetunableStaticModel, TextDataset

logger = logging.getLogger(__name__)


class ClassificationStaticModel(FinetunableStaticModel):
    def __init__(
        self,
        *,
        vectors: torch.Tensor,
        tokenizer: Tokenizer,
        n_layers: int,
        hidden_dim: int,
        out_dim: int,
        pad_id: int = 0,
    ) -> None:
        """Initialize a standard classifier model."""
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # Alias that follows scikit-learn conventions. Initialized to dummy classes.
        self.classes_: list[str] = [str(x) for x in range(out_dim)]
        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

    @property
    def classes(self) -> list[str]:
        """Return all classes in the correct order."""
        return self.classes_

    def construct_head(self) -> nn.Module:
        """Construct a simple classifier head."""
        if self.n_layers == 0:
            return nn.Linear(self.embed_dim, self.out_dim)
        modules = [
            nn.Linear(self.embed_dim, self.hidden_dim),
            nn.ReLU(),
        ]
        for _ in range(self.n_layers - 1):
            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

        for module in modules:
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        return nn.Sequential(*modules)

    def predict(self, X: list[str]) -> list[str]:
        """Predict a class for a set of texts."""
        pred: list[str] = []
        for batch in range(0, len(X), 1024):
            logits = self._predict(X[batch : batch + 1024])
            pred.extend([self.classes[idx] for idx in logits.argmax(1)])

        return pred

    @torch.no_grad()
    def _predict(self, X: list[str]) -> torch.Tensor:
        input_ids = self.tokenize(X)
        logits, _ = self.forward(input_ids)
        return logits

    def predict_proba(self, X: list[str]) -> np.ndarray:
        """Predict the probability of each class."""
        pred: list[np.ndarray] = []
        for batch in range(0, len(X), 1024):
            logits = self._predict(X[batch : batch + 1024])
            pred.append(torch.softmax(logits, dim=1).numpy())

        return np.concatenate(pred)

    def fit(
        self,
        X: list[str],
        y: list[str],
        **kwargs: Any,
    ) -> ClassificationStaticModel:
        """Fit a model."""
        pl.seed_everything(42)

        classes = sorted(set(y))
        self.classes_ = classes

        if len(self.classes) != self.out_dim:
            self.out_dim = len(self.classes)

        # Reset the head and the embeddings before fitting.
        self.head = self.construct_head()
        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)

        label_mapping = {label: idx for idx, label in enumerate(self.classes)}
        label_counts = Counter(y)
        if min(label_counts.values()) < 2:
            logger.info("Some classes have fewer than 2 samples. Stratification is disabled.")
            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
                X, y, test_size=0.1, random_state=42, shuffle=True
            )
        else:
            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
                X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
            )

        # Tokenize the texts and turn the labels into a LongTensor.
        train_tokenized: list[list[int]] = [
            encoding.ids for encoding in self.tokenizer.encode_batch_fast(train_texts, add_special_tokens=False)
        ]
        train_labels_tensor = torch.Tensor([label_mapping[label] for label in train_labels]).long()
        train_dataset = TextDataset(train_tokenized, train_labels_tensor)

        val_tokenized: list[list[int]] = [
            encoding.ids for encoding in self.tokenizer.encode_batch_fast(validation_texts, add_special_tokens=False)
        ]
        val_labels_tensor = torch.Tensor([label_mapping[label] for label in validation_labels]).long()
        val_dataset = TextDataset(val_tokenized, val_labels_tensor)

        lightning_module = ClassifierLightningModule(self)

        batch_size = 32
        n_train_batches = len(train_dataset) // batch_size
        callbacks: list[Callback] = [EarlyStopping(monitor="val_accuracy", mode="max", patience=5)]
        if n_train_batches < 250:
            trainer = pl.Trainer(max_epochs=500, callbacks=callbacks, check_val_every_n_epoch=1)
        else:
            val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
            trainer = pl.Trainer(
                max_epochs=500, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=None
            )

        trainer.fit(
            lightning_module,
            train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
            val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
        )
        best_model_path = trainer.checkpoint_callback.best_model_path  # type: ignore

        # Restore the best checkpoint into this model.
        state_dict = {
            k.removeprefix("model."): v for k, v in torch.load(best_model_path, weights_only=True)["state_dict"].items()
        }
        self.load_state_dict(state_dict)

        self.eval()

        return self


class ClassifierLightningModule(pl.LightningModule):
    def __init__(self, model: ClassificationStaticModel) -> None:
        """Initialize the LightningModule."""
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Simple forward pass."""
        return self.model(x)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Simple training step using cross entropy loss."""
        x, y = batch
        head_out, _ = self.model(x)
        loss = nn.functional.cross_entropy(head_out, y).mean()

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Simple validation step using cross entropy loss and accuracy."""
        x, y = batch
        head_out, _ = self.model(x)
        loss = nn.functional.cross_entropy(head_out, y).mean()
        accuracy = (head_out.argmax(1) == y).float().mean()

        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Simple Adam optimizer with a reduce-on-plateau scheduler."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=3,
            verbose=True,
            min_lr=1e-6,
            threshold=0.03,
            threshold_mode="rel",
        )

        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
(Diffs of additional changed files failed to load.)