Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 82 additions & 7 deletions pkg/tokenization/prefixstore/lru_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ limitations under the License.
package prefixstore //nolint:testpackage // convenience

import (
"strings"
"testing"

"github.com/daulet/tokenizers"
"github.com/stretchr/testify/assert"
)

func TestLRUTokenStore_AddAndRetrieve(t *testing.T) {
//nolint:gocritic
func setupTestLRUTokenStore(t *testing.T, blockSize int) (*LRUTokenStore, string, string, []uint32) {
t.Helper()
store, err := NewLRUTokenStore(&Config{
&LRUStoreConfig{CacheSize: defaultMaxCacheSize, BlockSize: 4},
&LRUStoreConfig{CacheSize: defaultMaxCacheSize, BlockSize: blockSize},
})
assert.NoError(t, err)

Expand All @@ -36,15 +39,87 @@ func TestLRUTokenStore_AddAndRetrieve(t *testing.T) {
{0, 3}, {4, 11}, {12, 14}, {15, 21}, {22, 24}, {25, 30},
}

// Add tokenization to the store
err = store.AddTokenization(modelName, text, tokens, offsets)
assert.NoError(t, err)

// Retrieve tokens for a matching prefix
lru, ok := store.(*LRUTokenStore)
assert.True(t, ok)
return lru, modelName, text, tokens
}

func TestLRUTokenStore_AddAndRetrieve(t *testing.T) {
lruStore, modelName, _, _ := setupTestLRUTokenStore(t, 4)

prompt := "The capital of F"
result, overlapRatio := store.FindLongestContainedTokens(prompt, modelName)
assert.Equal(t, []uint32{1, 2, 3}, result)
assert.Equal(t, overlapRatio, 1.0)
actualTokens, overlapRatio := lruStore.FindLongestContainedTokens(prompt, modelName)
assert.Equal(t, []uint32{1, 2, 3}, actualTokens, "FindLongestContainedTokens(%q) result mismatch", prompt)
assert.Equal(t, 1.0, overlapRatio, "FindLongestContainedTokens(%q) overlapRatio should be 1.0", prompt)
}

func TestLRUTokenStore_PartialMismatch(t *testing.T) {
blockSize := 4
lruStore, modelName, _, expectedTokens := setupTestLRUTokenStore(t, blockSize)

cases := []struct {
prompt string
matchPrompt string
tokenContains int
}{
{
prompt: "The capital of France is Marseille",
matchPrompt: "The capital of France is",
tokenContains: 4,
},
{
prompt: "The capital of Japan is Tokyo",
matchPrompt: "The capital ",
tokenContains: 2,
},
{
prompt: "The capital of F",
matchPrompt: "The capital of F",
tokenContains: 3,
},
{
prompt: "The capital of France is Paris",
matchPrompt: "The capital of France is Par",
tokenContains: 5,
},
}
for _, c := range cases {
numerator := len(c.matchPrompt)
denominator := len(c.prompt)
expectedRatio := float64(numerator) / float64(denominator)
actualTokens, overlapRatio := lruStore.FindLongestContainedTokens(c.prompt, modelName)
t.Logf("prompt: %q, actualTokens: %v, overlapRatio: %v\n", c.prompt, actualTokens, overlapRatio)
assert.Subset(t, expectedTokens, actualTokens, "prompt=%q, result should be subset of tokens", c.prompt)
assert.LessOrEqual(t, c.tokenContains, len(actualTokens), "prompt=%q, token count mismatch", c.prompt)
assert.Equal(t, expectedRatio, overlapRatio, "prompt=%q, overlapRatio mismatch", c.prompt)
}
}

func TestLRUTokenStore_PrefixMatch(t *testing.T) {
blockSize := 4
lruStore, modelName, text, expectedTokens := setupTestLRUTokenStore(t, blockSize)

words := strings.Split(text, " ")
prefix := ""
for i, word := range words {
if i > 0 {
prefix += " "
}
prefix += word

actualTokens, overlapRatio := lruStore.FindLongestContainedTokens(prefix, modelName)
t.Logf("word: %q, prefix: %q, actualTokens: %v, overlapRatio: %v", word, prefix, actualTokens, overlapRatio)
assert.Subset(t, expectedTokens, actualTokens, "prefix=%q, result should be subset of tokens", prefix)

b := len(prefix) / blockSize
minVal := float64(b*blockSize) / float64(len(prefix))
assert.GreaterOrEqual(t, len(actualTokens), i, "prefix=%q, token count should be >= i", prefix)
assert.LessOrEqual(t, len(actualTokens), i+1, "prefix=%q, token count should be <= i+1", prefix)
assert.GreaterOrEqual(t, overlapRatio, minVal, "prefix=%q, overlapRatio should be >= min", prefix)
}
}

func TestLRUTokenStore_LRUEviction(t *testing.T) {
Expand Down
Loading