
Commit 4c1a3ec

Merge branch 'master' into develop
2 parents: 78cbbf6 + 01a700c

4 files changed: 132 additions, 50 deletions

scripts/06_precompute_cache.py

Lines changed: 94 additions & 26 deletions
@@ -33,6 +33,8 @@ def main(
     import cupy as xp
     import cupy.cuda.device
 
+    cupy.take_along_axis = take_along_axis
+    cupy.put_along_axis = put_along_axis
     device = cupy.cuda.device.Device(gpu_id)
     device.use()
     vectors_dir = Path(vectors)
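
The two added assignments monkeypatch the backports (defined at the bottom of this script) onto the cupy module, so the batch loop below can call xp.take_along_axis and xp.put_along_axis uniformly. A more defensive variant of the patch (my sketch, not part of the commit) would only patch when the attribute is missing, so cupy's native implementations win once cupy 7 ships them:

    import cupy

    # Hypothetical guard: prefer cupy's own implementations when present,
    # fall back to this script's backports otherwise.
    if not hasattr(cupy, "take_along_axis"):
        cupy.take_along_axis = take_along_axis
    if not hasattr(cupy, "put_along_axis"):
        cupy.put_along_axis = put_along_axis
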
@@ -56,35 +58,26 @@ def main(
     msg.good(f"Normalized (mean {mean:,.2f}, variance {var:,.2f})")
     msg.info(f"Finding {n_neighbors:,} neighbors among {cutoff:,} most frequent")
     n = min(n_neighbors, vectors.shape[0])
+    subset = vectors[:cutoff]
     best_rows = xp.zeros((end - start, n), dtype="i")
     scores = xp.zeros((end - start, n), dtype="f")
-    # Pre-allocate this array, so we can use it each time.
-    subset = xp.ascontiguousarray(vectors[:cutoff])
-    sims = xp.zeros((batch_size, cutoff), dtype="f")
-    indices = xp.arange(cutoff).reshape((-1, 1))
     for i in tqdm.tqdm(list(range(start, end, batch_size))):
-        batch = vectors[i : i + batch_size]
-        # batch e.g. (1024, 300)
-        # vectors e.g. (10000, 300)
-        # sims e.g. (1024, 10000)
-        if batch.shape[0] == sims.shape[0]:
-            xp.dot(batch, subset.T, out=sims)
-        else:
-            # In the last batch we'll have a different size.
-            sims = xp.dot(batch, subset.T)
-        size = sims.shape[0]
-        # Get the indices and scores for the top N most similar for each in the
-        # batch. This is a bit complicated, to avoid sorting all of the scores
-        # -- we only want the top N to be sorted (which we do later). For now,
-        # we use argpartition to just get the cut point.
-        neighbors = xp.argpartition(sims, -n, axis=1)[:, -n:]
-        neighbor_sims = xp.partition(sims, -n, axis=1)[:, -n:]
-        # Can't figure out how to do this without the loop.
-        for j in range(min(end - i, size)):
-            # Sort in reverse order
-            indices = xp.argsort(neighbor_sims[j], axis=-1)[::-1]
-            best_rows[i + j] = xp.take(neighbors[j], indices)
-            scores[i + j] = xp.take(neighbor_sims[j], indices)
+        size = min(batch_size, end - i)
+        batch = vectors[i : i + size]
+        sims = xp.dot(batch, subset.T)
+        # Set self-similarities to -inf, so that we don't return them.
+        indices = xp.arange(i, min(i + size, sims.shape[1])).reshape((-1, 1))
+        xp.put_along_axis(sims, indices, -xp.inf, axis=1)
+        # This used to use argpartition, to do a partial sort... But this ended
+        # up being a rat's nest of terrible numpy crap. Just sorting the whole
+        # list isn't really slower, and it's much simpler to read.
+        ranks = xp.argsort(sims, axis=1)
+        batch_rows = ranks[:, -n:]
+        # Reverse, so scores run from highest to lowest.
+        batch_rows = batch_rows[:, ::-1]
+        batch_scores = xp.take_along_axis(sims, batch_rows, axis=1)
+        best_rows[i : i + size] = batch_rows
+        scores[i : i + size] = batch_scores
     msg.info("Saving output")
     if not isinstance(best_rows, numpy.ndarray):
         best_rows = best_rows.get()
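
The rewritten loop masks each row's self-similarity with -inf, does one full argsort per batch, slices the last n columns, and flips them so scores run from best to worst. A minimal numpy sketch of the same pattern (toy shapes and names, mine rather than the script's):

    import numpy as np

    rng = np.random.default_rng(0)
    sims = rng.random((4, 10)).astype("f")  # (batch, cutoff) similarity matrix
    n = 3

    # Mask self-similarity so a row never returns itself (rows 0..3 here).
    np.put_along_axis(sims, np.arange(4).reshape(-1, 1), -np.inf, axis=1)

    # Full sort per row, keep the n highest, reverse into descending order.
    ranks = np.argsort(sims, axis=1)
    top_rows = ranks[:, -n:][:, ::-1]  # (batch, n) neighbor indices
    top_scores = np.take_along_axis(sims, top_rows, axis=1)

    assert top_scores.shape == (4, n)
    assert (np.diff(top_scores, axis=1) <= 0).all()  # descending per row
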
@@ -103,6 +96,81 @@ def main(
     msg.good(f"Saved cache to {output_file}")
 
 
+# These functions are missing from cupy, but will be supported in cupy 7.
+def take_along_axis(a, indices, axis):
+    """Take values from the input array by matching 1d index and data slices.
+
+    Args:
+        a (cupy.ndarray): Array to extract elements.
+        indices (cupy.ndarray): Indices to take along each 1d slice of ``a``.
+        axis (int): The axis to take 1d slices along.
+
+    Returns:
+        cupy.ndarray: The indexed result.
+
+    .. seealso:: :func:`numpy.take_along_axis`
+    """
+    import cupy
+
+    if indices.dtype.kind not in ("i", "u"):
+        raise IndexError("`indices` must be an integer array")
+
+    if axis is None:
+        a = a.ravel()
+        axis = 0
+
+    ndim = a.ndim
+
+    if not (-ndim <= axis < ndim):
+        raise IndexError("Axis overrun")
+
+    axis %= a.ndim
+
+    if ndim != indices.ndim:
+        raise ValueError("`indices` and `a` must have the same number of dimensions")
+
+    fancy_index = []
+    for i, n in enumerate(a.shape):
+        if i == axis:
+            fancy_index.append(indices)
+        else:
+            ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
+            fancy_index.append(cupy.arange(n).reshape(ind_shape))
+
+    return a[fancy_index]
+
+
+def put_along_axis(a, indices, value, axis):
+    import cupy
+
+    if indices.dtype.kind not in ("i", "u"):
+        raise IndexError("`indices` must be an integer array")
+
+    if axis is None:
+        a = a.ravel()
+        axis = 0
+
+    ndim = a.ndim
+
+    if not (-ndim <= axis < ndim):
+        raise IndexError("Axis overrun")
+
+    axis %= a.ndim
+
+    if ndim != indices.ndim:
+        raise ValueError("`indices` and `a` must have the same number of dimensions")
+
+    fancy_index = []
+    for i, n in enumerate(a.shape):
+        if i == axis:
+            fancy_index.append(indices)
+        else:
+            ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
+            fancy_index.append(cupy.arange(n).reshape(ind_shape))
+
+    a[fancy_index] = value
+
+
 if __name__ == "__main__":
     try:
         plac.call(main)
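
Both backports build a fancy index: the caller's indices on the target axis, plus broadcast-ready arange arrays on every other axis. A quick numpy check of that construction (a sketch for illustration, substituting numpy for cupy so it runs without a GPU; fancy_take is my name, not part of the script):

    import numpy as np

    def fancy_take(a, indices, axis):
        # Same construction as the backport: arange on every axis except
        # `axis`, reshaped so the index arrays broadcast against `indices`.
        fancy = []
        for i, n in enumerate(a.shape):
            if i == axis:
                fancy.append(indices)
            else:
                shape = (1,) * i + (-1,) + (1,) * (a.ndim - i - 1)
                fancy.append(np.arange(n).reshape(shape))
        return a[tuple(fancy)]

    a = np.arange(12).reshape(3, 4)
    idx = np.argsort(a, axis=1)[:, -2:]  # top-2 columns per row
    assert (fancy_take(a, idx, axis=1) == np.take_along_axis(a, idx, axis=1)).all()
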

sense2vec/sense2vec.py

Lines changed: 34 additions & 22 deletions
@@ -30,6 +30,7 @@ def __init__(
         RETURNS (Sense2Vec): The newly constructed object.
         """
         self.vectors = Vectors(shape=shape, name=vectors_name)
+        self._row2key = None
         self.strings = StringStore() if strings is None else strings
         self.freqs: Dict[int, int] = {}
         self.cache = None
@@ -87,6 +88,7 @@ def __setitem__(self, key: Union[str, int], vector: numpy.ndarray):
         if key not in self.vectors:
             raise ValueError(f"Can't find key {key} in table")
         self.vectors[key] = vector
+        self._row2key = None
 
     def __iter__(self):
         """YIELDS (tuple): String key and vector pairs in the table."""
@@ -106,6 +108,12 @@ def values(self):
         """YIELDS (numpy.ndarray): The vectors in the table."""
         yield from self.vectors.values()
 
+    @property
+    def row2key(self):
+        if not self._row2key:
+            self._row2key = {row: key for key, row in self.vectors.key2row.items()}
+        return self._row2key
+
     @property
     def make_key(self) -> Callable:
         """Get the function to make keys."""
@@ -128,6 +136,7 @@ def add(self, key: Union[str, int], vector: numpy.ndarray, freq: int = None):
         self.vectors.add(key, vector=vector)
         if freq is not None:
             self.set_freq(key, freq)
+        self._row2key = None
 
     def get_freq(self, key: Union[str, int], default=None) -> Union[int, None]:
         """Get the frequency count for a given key.
@@ -200,31 +209,32 @@ def most_similar(
         """
         if isinstance(keys, (str, int)):
             keys = [keys]
-        # Always ask for more because we'll always get the keys themselves
-        n_similar = n + len(keys)
         for key in keys:
            if key not in self:
                raise ValueError(f"Can't find key {key} in table")
-        if len(self.vectors) < n_similar:
-            n_similar = len(self.vectors)
-        if self.cache:
-            indices = self.cache.get("indices", [])
-            scores = self.cache.get("scores", [])
-            if len(indices) >= n_similar:
-                key_row = self.vectors.find(key=key)
-                sim_keys = self.vectors.find(rows=indices[key_row][:n_similar])
-                sim_scores = scores[key_row][:n_similar]
-                result = [(self.strings[k], s) for k, s in zip(sim_keys, sim_scores)]
-                return [(key, score) for key, score in result if key not in keys]
-        vecs = numpy.vstack([self[key] for key in keys])
-        average = vecs.mean(axis=0, keepdims=True)
-        result_keys, _, scores = self.vectors.most_similar(
-            average, n=n_similar, batch_size=batch_size
-        )
-        result = list(zip(result_keys.flatten(), scores.flatten()))
-        result = [(self.strings[key], score) for key, score in result if key]
-        result = [(key, score) for key, score in result if key not in keys]
-        return result
+        if self.cache and self.cache["indices"].shape[1] >= n:
+            n = min(len(self.vectors), n)
+            key = self.ensure_int_key(key)
+            key_row = self.vectors.find(key=key)
+            rows = self.cache["indices"][key_row, :n]
+            scores = self.cache["scores"][key_row, :n]
+            keys = [self.row2key[r] for r in rows]
+            keys = [self.strings[k] for k in keys]
+            assert len(keys) == len(scores)
+            return list(zip(keys, scores))
+        else:
+            # Always ask for more because we'll always get the keys themselves
+            n = min(len(self.vectors), n + len(keys))
+            rows = numpy.asarray(self.vectors.find(keys=keys))
+            vecs = self.vectors.data[rows]
+            average = vecs.mean(axis=0, keepdims=True)
+            result_keys, _, scores = self.vectors.most_similar(
+                average, n=n, batch_size=batch_size
+            )
+            result = list(zip(result_keys.flatten(), scores.flatten()))
+            result = [(self.strings[key], score) for key, score in result if key]
+            result = [(key, score) for key, score in result if key not in keys]
+            return result
 
     def get_other_senses(
         self, key: Union[str, int], ignore_case: bool = True
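
With this change, most_similar is served straight from the precomputed cache whenever one is attached and holds at least n neighbors per row; otherwise it falls back to averaging the queried vectors and running a full similarity search. A usage sketch, assuming a model on disk (the path and keys here are hypothetical):

    from sense2vec import Sense2Vec

    s2v = Sense2Vec().from_disk("/path/to/s2v_model")  # hypothetical path
    # Served from the cache when s2v.cache holds enough neighbors,
    # otherwise computed via the averaged-vector search.
    print(s2v.most_similar("natural_language_processing|NOUN", n=3))
    print(s2v.most_similar(["duck|NOUN", "goose|NOUN"], n=3))
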
@@ -302,6 +312,7 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
         self.strings = StringStore().from_bytes(data["strings"])
         if "cache" not in exclude and "cache" in data:
             self.cache = data.get("cache", {})
+        self._row2key = None
         return self
 
     def to_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
@@ -338,4 +349,5 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
         self.strings = StringStore().from_disk(strings_path)
         if "cache" not in exclude and cache_path.exists():
             self.cache = srsly.read_msgpack(cache_path)
+        self._row2key = None
         return self

tests/data/cache

Binary file (0 bytes) not shown.

tests/test_model.py

Lines changed: 4 additions & 2 deletions
@@ -25,8 +25,10 @@ def test_model_most_similar_cache(s2v):
     # Modify cache to test that the cache is used and values aren't computed
     query_row = s2v.vectors.find(key=s2v.ensure_int_key(query))
     scores = numpy.array(s2v.cache["scores"], copy=True)  # otherwise not writable
-    scores[query_row, 1] = 2.0
-    scores[query_row, 2] = 3.0
+    honey_bees_row = s2v.vectors.find(key="honey_bees|NOUN")
+    scores[query_row, honey_bees_row] = 2.0
+    beekeepers_row = s2v.vectors.find(key="Beekepers|NOUN")
+    scores[query_row, beekeepers_row] = 3.0
     s2v.cache["scores"] = scores
     ((key1, score1), (key2, score2)) = s2v.most_similar([query], n=2)
     assert key1 == "honey_bees|NOUN"
