Description
While building a Model2Vec model, I've been exploring different parameter configurations.
As part of that, I've also looked at the post-training regularization. I explored a similar problem space years ago (see this article).
Back then, I did something similar, except that the process weighted fastText embeddings. I found that BM25 weighting worked pretty well.
Not sure if you've explored this, but I did a quick prototype with a model I'm training and saw a performance gain: the Pearson correlation coefficient (PCC) increased from 90.37 to 91.99.
The code I used is below if you'd like to try it. It can be called instead of weight_model.
import numpy as np

from model2vec import StaticModel
from more_itertools import batched
from sklearn.decomposition import PCA
from tokenlearn.train import train_model
from txtai.scoring import ScoringFactory
from tqdm import tqdm


def tokenize(tokenizer, texts):
    for t in tqdm(batched(texts, 1024)):
        encodings = tokenizer.encode_batch_fast(t, add_special_tokens=False)
        for e in encodings:
            yield (None, e.ids, None)


def weight(model, texts, pca, method):
    tokenizer = model.tokenizer

    # Build scoring index
    scoring = ScoringFactory.create({"method": method, "terms": True})
    scoring.index(tokenize(tokenizer, texts))

    # Calculate weights
    scores = {}
    for token in scoring.idf:
        _, weights = scoring.terms.weights(token)
        scores[token] = np.mean(weights)

    # Get weights array
    f = np.zeros(tokenizer.get_vocab_size())
    for uid, score in scores.items():
        f[uid] += score

    # Get embeddings
    w = model.embedding
    w = np.nan_to_num(w)

    # Apply PCA
    p = PCA(n_components=pca)
    w = p.fit_transform(w)

    # Apply weights
    w *= f[:, None]

    # Save embeddings to model and normalize
    model.embedding = w
    model.normalize = True

    return model


# Train the model
model = train_model(name, texts, vectors)

# Weight using BM25
weight(model, texts, 256, "bm25")

The code above uses BM25 scoring from txtai, but there are other Python libraries available for BM25 scoring, or you could roll your own.
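For anyone who wants to try the "roll your own" route, here is a minimal sketch of per-token BM25 weights using only NumPy and the standard library. It assumes the standard Okapi BM25 formula with the common defaults k1=1.2 and b=0.75, and it averages each token's BM25 weight over the documents that contain it, mirroring the np.mean step above. The function name bm25_token_weights and the toy corpus are hypothetical, and the numbers are not guaranteed to match txtai's output exactly.

```python
from collections import Counter
import math

import numpy as np


def bm25_token_weights(docs, vocab_size, k1=1.2, b=0.75):
    """Mean BM25 weight per token id over the documents that contain it.

    docs is a non-empty list of token id lists. k1 and b are assumed defaults.
    """
    n_docs = len(docs)
    # Average document length; fall back to 1.0 if every document is empty
    avgdl = (sum(len(d) for d in docs) / n_docs) or 1.0

    # Document frequency: number of documents containing each token id
    df = Counter()
    for doc in docs:
        df.update(set(doc))

    # Sum the BM25 term weight of each token over all documents
    totals = np.zeros(vocab_size)
    for doc in docs:
        norm = k1 * (1 - b + b * len(doc) / avgdl)
        for token, tf in Counter(doc).items():
            idf = math.log(1 + (n_docs - df[token] + 0.5) / (df[token] + 0.5))
            totals[token] += idf * tf * (k1 + 1) / (tf + norm)

    # Average over the documents that contain the token; unseen tokens stay 0
    weights = np.zeros(vocab_size)
    for token, count in df.items():
        weights[token] = totals[token] / count
    return weights


# Toy corpus of token ids (hypothetical; real ids would come from the tokenizer)
docs = [[0, 1, 2], [0, 1, 3], [0, 4, 4]]
weights = bm25_token_weights(docs, vocab_size=6)

# Scale each embedding row by its token weight, as in `w *= f[:, None]`
embeddings = np.ones((6, 4))
weighted = embeddings * weights[:, None]
```

Tokens that appear in every document (id 0 here) get a small weight, rarer tokens get a larger one, and tokens never seen in the corpus get a weight of zero, which zeroes out their embedding rows just as the np.zeros initialization does in the code above.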