Skip to content

Commit 274d917

Browse files
committed
Fix cache test
1 parent 9f11575 commit 274d917

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/test_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@ def test_model_most_similar_cache(s2v):
2222
query = "beekeepers|NOUN"
2323
assert s2v.cache
2424
assert query in s2v
25+
indices = s2v.cache["indices"]
2526
# Modify cache to test that the cache is used and values aren't computed
2627
query_row = s2v.vectors.find(key=s2v.ensure_int_key(query))
2728
scores = numpy.array(s2v.cache["scores"], copy=True) # otherwise not writable
2829
honey_bees_row = s2v.vectors.find(key="honey_bees|NOUN")
29-
scores[query_row, honey_bees_row] = 2.0
30-
beekeepers_row = s2v.vectors.find(key="Beekepers|NOUN")
31-
scores[query_row, beekeepers_row] = 3.0
30+
beekeepers_row = s2v.vectors.find(key="Beekeepers|NOUN")
31+
for i in range(indices.shape[0]):
32+
for j in range(indices.shape[1]):
33+
if indices[i, j] == honey_bees_row:
34+
scores[i, j] = 2.0
35+
elif indices[i, j] == beekeepers_row:
36+
scores[i, j] = 3.0
3237
s2v.cache["scores"] = scores
3338
((key1, score1), (key2, score2)) = s2v.most_similar([query], n=2)
3439
assert key1 == "honey_bees|NOUN"

0 commit comments

Comments
 (0)