Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"time"

mapset "github.com/deckarep/golang-set/v2"
Expand Down Expand Up @@ -99,6 +100,8 @@ type Worker struct {
conn *grpc.ClientConn
masterClient protocol.MasterClient

// model state
modelMutex sync.RWMutex
latestRankingModelVersion int64
latestClickModelVersion int64
matrixFactorization *logics.MatrixFactorization
Expand Down Expand Up @@ -253,14 +256,37 @@ func (w *Worker) Sync() {
}
}

// getModelState returns the current ranking model and matrix factorization in a thread-safe way.
// The returned values should be used immediately and not stored.
func (w *Worker) getModelState() (cf.MatrixFactorization, *logics.MatrixFactorization) {
w.modelMutex.RLock()
defer w.modelMutex.RUnlock()
return w.RankingModel, w.matrixFactorization
}

func (w *Worker) getModelVersion() int64 {
w.modelMutex.RLock()
defer w.modelMutex.RUnlock()
return w.RankingModelVersion
}

// updateModel updates the model state and invalidates the matrix factorization.
func (w *Worker) updateModel(model cf.MatrixFactorization, version int64) {
w.modelMutex.Lock()
defer w.modelMutex.Unlock()
w.RankingModel = model
w.matrixFactorization = nil
w.RankingModelVersion = version
}

// Pull user index and ranking model from master.
func (w *Worker) Pull() {
defer base.CheckPanic()
for range w.syncedChan.C {
pulled := false

// pull ranking model
if w.latestRankingModelVersion != w.RankingModelVersion {
if w.latestRankingModelVersion != w.getModelVersion() {
log.Logger().Info("start pull ranking model")
if rankingModelReceiver, err := w.masterClient.GetRankingModel(context.Background(),
&protocol.VersionInfo{Version: w.latestRankingModelVersion},
Expand All @@ -272,12 +298,10 @@ func (w *Worker) Pull() {
if err != nil {
log.Logger().Error("failed to unmarshal ranking model", zap.Error(err))
} else {
w.RankingModel = rankingModel
w.matrixFactorization = nil
w.RankingModelVersion = w.latestRankingModelVersion
w.updateModel(rankingModel, w.latestRankingModelVersion)
log.Logger().Info("synced ranking model",
zap.String("version", encoding.Hex(w.RankingModelVersion)))
MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(sizeof.DeepSize(w.RankingModel)))
zap.String("version", encoding.Hex(w.latestRankingModelVersion)))
MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(sizeof.DeepSize(rankingModel)))
pulled = true
}
}
Expand Down Expand Up @@ -542,20 +566,29 @@ func (w *Worker) Recommend(users []data.User) {
}()

// build ranking index
if w.RankingModel != nil && !w.RankingModel.Invalid() && w.matrixFactorization == nil {
startTime := time.Now()
log.Logger().Info("start building ranking index")
itemIndex := w.RankingModel.GetItemIndex()
matrixFactorization := logics.NewMatrixFactorization(time.Now())
for i := int32(0); i < itemIndex.Count(); i++ {
if itemId, ok := itemIndex.String(i); ok && itemCache.IsAvailable(itemId) {
item, _ := itemCache.Get(itemId)
matrixFactorization.Add(item, w.RankingModel.GetItemFactor(int32(i)))
rankingModel, matrix := w.getModelState()
if rankingModel != nil && !rankingModel.Invalid() && matrix == nil {
w.modelMutex.Lock()

// double check
if w.matrixFactorization == nil {
startTime := time.Now()
log.Logger().Info("start building ranking index")
itemIndex := rankingModel.GetItemIndex()
matrixFactorization := logics.NewMatrixFactorization(time.Now())
for i := int32(0); i < itemIndex.Count(); i++ {
if itemId, ok := itemIndex.String(i); ok && itemCache.IsAvailable(itemId) {
item, _ := itemCache.Get(itemId)
itemFactor := rankingModel.GetItemFactor(int32(i))
matrixFactorization.Add(item, itemFactor)
}
}
w.matrixFactorization = matrixFactorization
log.Logger().Info("complete building ranking index",
zap.Duration("build_time", time.Since(startTime)))
}
w.matrixFactorization = matrixFactorization
log.Logger().Info("complete building ranking index",
zap.Duration("build_time", time.Since(startTime)))
matrix = w.matrixFactorization
w.modelMutex.Unlock()
}

// recommendation
Expand Down Expand Up @@ -613,11 +646,11 @@ func (w *Worker) Recommend(users []data.User) {

// Recommender #1: collaborative filtering.
collaborativeUsed := false
if w.Config.Recommend.Offline.EnableColRecommend && w.RankingModel != nil && !w.RankingModel.Invalid() {
if userIndex := w.RankingModel.GetUserIndex().Id(userId); w.RankingModel.IsUserPredictable(int32(userIndex)) {
if w.Config.Recommend.Offline.EnableColRecommend && rankingModel != nil && !rankingModel.Invalid() {
if userIndex := rankingModel.GetUserIndex().Id(userId); rankingModel.IsUserPredictable(int32(userIndex)) {
var recommend map[string][]string
var usedTime time.Duration
recommend, usedTime, err = w.collaborativeRecommendHNSW(w.matrixFactorization, userId, excludeSet, itemCache)
recommend, usedTime, err = w.collaborativeRecommendHNSW(rankingModel, matrix, userId, excludeSet, itemCache)
if err != nil {
log.Logger().Error("failed to recommend by collaborative filtering",
zap.String("user_id", userId), zap.Error(err))
Expand All @@ -628,10 +661,10 @@ func (w *Worker) Recommend(users []data.User) {
}
collaborativeUsed = true
collaborativeRecommendSeconds.Add(usedTime.Seconds())
} else if !w.RankingModel.IsUserPredictable(int32(userIndex)) {
} else if !rankingModel.IsUserPredictable(int32(userIndex)) {
log.Logger().Debug("user is unpredictable", zap.String("user_id", userId))
}
} else if w.RankingModel == nil || w.RankingModel.Invalid() {
} else if rankingModel == nil || rankingModel.Invalid() {
log.Logger().Debug("no collaborative filtering model")
}

Expand Down Expand Up @@ -792,8 +825,8 @@ func (w *Worker) Recommend(users []data.User) {
return errors.Trace(err)
}
ctrUsed = true
} else if w.RankingModel != nil && !w.RankingModel.Invalid() &&
w.RankingModel.IsUserPredictable(w.RankingModel.GetUserIndex().Id(userId)) {
} else if rankingModel != nil && !rankingModel.Invalid() &&
rankingModel.IsUserPredictable(rankingModel.GetUserIndex().Id(userId)) {
results[category], err = w.rankByCollaborativeFiltering(userId, catCandidates)
if err != nil {
log.Logger().Error("failed to rank items", zap.Error(err))
Expand Down Expand Up @@ -860,11 +893,12 @@ func (w *Worker) Recommend(users []data.User) {
OfflineRecommendStepSecondsVec.WithLabelValues("popular_recommend").Set(popularRecommendSeconds.Load())
}

func (w *Worker) collaborativeRecommendHNSW(rankingIndex *logics.MatrixFactorization, userId string, excludeSet mapset.Set[string], itemCache *ItemCache) (map[string][]string, time.Duration, error) {
func (w *Worker) collaborativeRecommendHNSW(rankingModel cf.MatrixFactorization, rankingIndex *logics.MatrixFactorization, userId string, excludeSet mapset.Set[string], itemCache *ItemCache) (map[string][]string, time.Duration, error) {
ctx := context.Background()
userIndex := w.RankingModel.GetUserIndex().Id(userId)
userIndex := rankingModel.GetUserIndex().Id(userId)
userFactor := rankingModel.GetUserFactor(userIndex)
localStartTime := time.Now()
scores := rankingIndex.Search(w.RankingModel.GetUserFactor(userIndex), w.Config.Recommend.CacheSize+excludeSet.Cardinality())
scores := rankingIndex.Search(userFactor, w.Config.Recommend.CacheSize+excludeSet.Cardinality())
// save result
recommend := make(map[string][]string)
for _, score := range scores {
Expand Down Expand Up @@ -901,11 +935,12 @@ func (w *Worker) rankByCollaborativeFiltering(userId string, candidates [][]stri
}
}
// rank by collaborative filtering
rankingModel, _ := w.getModelState()
topItems := make([]cache.Score, 0, len(candidates))
for _, itemId := range itemIds {
topItems = append(topItems, cache.Score{
Id: itemId,
Score: float64(w.RankingModel.Predict(userId, itemId)),
Score: float64(rankingModel.Predict(userId, itemId)),
})
}
cache.SortDocuments(topItems)
Expand Down Expand Up @@ -1229,6 +1264,9 @@ func (w *Worker) replacement(recommend map[string][]cache.Score, user *data.User
}
}

// Get model state once before the loop
rankingModel, _ := w.getModelState()

for _, itemId := range distinctItems.ToSlice() {
if item, exist := itemCache.Get(itemId); exist {
// scoring item
Expand All @@ -1238,8 +1276,8 @@ func (w *Worker) replacement(recommend map[string][]cache.Score, user *data.User
var score float64
if w.Config.Recommend.Offline.EnableClickThroughPrediction && w.ClickModel != nil {
score = float64(w.ClickModel.Predict(user.UserId, itemId, click.ConvertLabelsToFeatures(user.Labels), click.ConvertLabelsToFeatures(item.Labels)))
} else if w.RankingModel != nil && !w.RankingModel.Invalid() && w.RankingModel.IsUserPredictable(w.RankingModel.GetUserIndex().Id(user.UserId)) {
score = float64(w.RankingModel.Predict(user.UserId, itemId))
} else if rankingModel != nil && !rankingModel.Invalid() && rankingModel.IsUserPredictable(rankingModel.GetUserIndex().Id(user.UserId)) {
score = float64(rankingModel.Predict(user.UserId, itemId))
} else {
upper := upperBounds[""]
lower := lowerBounds[""]
Expand Down
Loading