Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def fit(
early_stopping_patience: int | None = 5,
test_size: float = 0.1,
device: str = "auto",
X_val: list[str] | None = None,
y_val: LabelType | None = None,
) -> StaticModelForClassification:
"""
Fit a model.
Expand All @@ -146,6 +148,9 @@ def fit(
This function seeds everything with a seed of 42, so the results are reproducible.
It also splits the data into a train and validation set, again with a random seed.

If `X_val` and `y_val` are not provided, the function will automatically
split the training data into a train and validation set using `test_size`.

:param X: The texts to train on.
:param y: The labels to train on. If the first element is a list, multi-label classification is assumed.
:param learning_rate: The learning rate.
Expand All @@ -157,7 +162,10 @@ def fit(
If this is None, early stopping is disabled.
:param test_size: The test size for the train-test split.
:param device: The device to train on. If this is "auto", the device is chosen automatically.
:param X_val: The texts to be used for validation.
:param y_val: The labels to be used for validation.
:return: The fitted model.
:raises ValueError: If either X_val or y_val are provided, but not both.
"""
pl.seed_everything(_RANDOM_SEED)
logger.info("Re-initializing model.")
Expand All @@ -166,11 +174,24 @@ def fit(

self._initialize(y)

train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
X,
y,
test_size=test_size,
)
if (X_val is not None) != (y_val is not None):
raise ValueError("Both X_val and y_val must be provided together, or neither.")

if X_val is not None and y_val is not None:
# Additional check to ensure y_val is of the same type as y
if type(y_val[0]) != type(y[0]):
raise ValueError("X_val and y_val must be of the same type as X and y.")

train_texts = X
train_labels = y
validation_texts = X_val
validation_labels = y_val
else:
train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
X,
y,
test_size=test_size,
)

if batch_size is None:
# Set to a multiple of 32
Expand Down
45 changes: 45 additions & 0 deletions tests/test_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import torch
from tokenizers import Tokenizer
from transformers import AutoTokenizer

from model2vec.model import StaticModel
from model2vec.train import StaticModelForClassification
Expand Down Expand Up @@ -154,6 +155,50 @@ def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -
assert len(d) == len(b)


def test_y_val_none() -> None:
    """
    Check that `fit` rejects a validation set that is only partially supplied.

    Passing just one of `X_val` / `y_val` is expected to raise a ValueError,
    while leaving both as None is expected to run without raising.
    """
    backend_tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
    # Seed before sampling so the random embedding matrix is reproducible.
    torch.random.manual_seed(42)
    vocab_size = len(backend_tokenizer.get_vocab())
    embeddings = torch.randn(vocab_size, 12)
    classifier = StaticModelForClassification(
        vectors=embeddings,
        tokenizer=backend_tokenizer,
        hidden_dim=12,
    ).to("cpu")

    train_texts, train_labels = ["dog", "cat"], ["0", "1"]
    val_texts, val_labels = ["dog", "cat"], ["0", "1"]

    # Each pair is missing exactly one half of the validation data.
    incomplete_pairs = ((val_texts, None), (None, val_labels))
    for partial_texts, partial_labels in incomplete_pairs:
        with pytest.raises(ValueError):
            classifier.fit(train_texts, train_labels, X_val=partial_texts, y_val=partial_labels)

    # With neither half supplied, the call must not raise.
    classifier.fit(train_texts, train_labels, X_val=None, y_val=None)


@pytest.mark.parametrize(
    "y_multi,y_val_multi,should_crash",
    [
        (True, True, False),
        (False, False, False),
        (True, False, True),
        (False, True, True),
    ],
)
def test_y_val(y_multi: bool, y_val_multi: bool, should_crash: bool) -> None:
    """
    Check `fit` with an explicit validation set across label-format combinations.

    :param y_multi: Whether the training labels are given as lists of labels per sample.
    :param y_val_multi: Whether the validation labels are given as lists of labels per sample.
    :param should_crash: Whether `fit` is expected to raise a ValueError for this combination.
    """
    backend_tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
    # Seed before sampling so the random embedding matrix is reproducible.
    torch.random.manual_seed(42)
    vocab_size = len(backend_tokenizer.get_vocab())
    embeddings = torch.randn(vocab_size, 12)
    classifier = StaticModelForClassification(
        vectors=embeddings,
        tokenizer=backend_tokenizer,
        hidden_dim=12,
    ).to("cpu")

    train_texts = ["dog", "cat"]
    if y_multi:
        train_labels = [["0", "1"], ["0"]]
    else:
        train_labels = ["0", "1"]  # type: ignore

    val_texts = ["dog", "cat"]
    if y_val_multi:
        val_labels = [["0", "1"], ["0"]]
    else:
        val_labels = ["0", "1"]  # type: ignore

    if not should_crash:
        # Matching label formats: the call must complete without raising.
        classifier.fit(train_texts, train_labels, X_val=val_texts, y_val=val_labels)
        return

    # Mismatched label formats between train and validation must be rejected.
    with pytest.raises(ValueError):
        classifier.fit(train_texts, train_labels, X_val=val_texts, y_val=val_labels)


def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
"""Test the evaluate function."""
if mock_trained_pipeline.multilabel:
Expand Down