@@ -33,6 +33,8 @@ def main(
33
33
import cupy as xp
34
34
import cupy .cuda .device
35
35
36
+ cupy .take_along_axis = take_along_axis
37
+ cupy .put_along_axis = put_along_axis
36
38
device = cupy .cuda .device .Device (gpu_id )
37
39
device .use ()
38
40
vectors_dir = Path (vectors )
@@ -56,35 +58,26 @@ def main(
56
58
msg .good (f"Normalized (mean { mean :,.2f} , variance { var :,.2f} )" )
57
59
msg .info (f"Finding { n_neighbors :,} neighbors among { cutoff :,} most frequent" )
58
60
n = min (n_neighbors , vectors .shape [0 ])
61
+ subset = vectors [:cutoff ]
59
62
best_rows = xp .zeros ((end - start , n ), dtype = "i" )
60
63
scores = xp .zeros ((end - start , n ), dtype = "f" )
61
- # Pre-allocate this array, so we can use it each time.
62
- subset = xp .ascontiguousarray (vectors [:cutoff ])
63
- sims = xp .zeros ((batch_size , cutoff ), dtype = "f" )
64
- indices = xp .arange (cutoff ).reshape ((- 1 , 1 ))
65
64
for i in tqdm .tqdm (list (range (start , end , batch_size ))):
66
- batch = vectors [i : i + batch_size ]
67
- # batch e.g. (1024, 300)
68
- # vectors e.g. (10000, 300)
69
- # sims e.g. (1024, 10000)
70
- if batch .shape [0 ] == sims .shape [0 ]:
71
- xp .dot (batch , subset .T , out = sims )
72
- else :
73
- # In the last batch we'll have a different size.
74
- sims = xp .dot (batch , subset .T )
75
- size = sims .shape [0 ]
76
- # Get the indices and scores for the top N most similar for each in the
77
- # batch. This is a bit complicated, to avoid sorting all of the scores
78
- # -- we only want the top N to be sorted (which we do later). For now,
79
- # we use argpartition to just get the cut point.
80
- neighbors = xp .argpartition (sims , - n , axis = 1 )[:, - n :]
81
- neighbor_sims = xp .partition (sims , - n , axis = 1 )[:, - n :]
82
- # Can't figure out how to do this without the loop.
83
- for j in range (min (end - i , size )):
84
- # Sort in reverse order
85
- indices = xp .argsort (neighbor_sims [j ], axis = - 1 )[::- 1 ]
86
- best_rows [i + j ] = xp .take (neighbors [j ], indices )
87
- scores [i + j ] = xp .take (neighbor_sims [j ], indices )
65
+ size = min (batch_size , end - i )
66
+ batch = vectors [i : i + size ]
67
+ sims = xp .dot (batch , subset .T )
68
+ # Set self-similarities to -inf, so that we don't return them.
69
+ indices = xp .arange (i , min (i + size , sims .shape [1 ])).reshape ((- 1 , 1 ))
70
+ xp .put_along_axis (sims , indices , - xp .inf , axis = 1 )
71
+ # This used to use argpartition, to do a partial sort...But this ended
72
+ # up being a ratsnest of terrible numpy crap. Just sorting the whole
73
+ # list isn't really slower, and it's much simpler to read.
74
+ ranks = xp .argsort (sims , axis = 1 )
75
+ batch_rows = ranks [:, - n :]
76
+ # Reverse
77
+ batch_rows = batch_rows [:, ::- 1 ]
78
+ batch_scores = xp .take_along_axis (sims , batch_rows , axis = 1 )
79
+ best_rows [i : i + size ] = batch_rows
80
+ scores [i : i + size ] = batch_scores
88
81
msg .info ("Saving output" )
89
82
if not isinstance (best_rows , numpy .ndarray ):
90
83
best_rows = best_rows .get ()
@@ -103,6 +96,81 @@ def main(
103
96
msg .good (f"Saved cache to { output_file } " )
104
97
105
98
99
+ # These functions are missing from cupy, but will be supported in cupy 7.
100
def take_along_axis(a, indices, axis):
    """Take values from the input array by matching 1d index and data slices.

    Backport of ``take_along_axis`` for array libraries that do not ship it
    yet. The index helper arrays are created with the array module that
    owns ``a`` (cupy for device arrays, numpy otherwise), so the function
    also works on plain numpy arrays and when cupy is not installed.

    Args:
        a (cupy.ndarray): Array to extract elements from.
        indices (cupy.ndarray): Integer indices to take along each 1d slice
            of ``a``. Must have the same number of dimensions as ``a``
            (one dimension when ``axis`` is ``None``).
        axis (int or None): The axis to take 1d slices along. ``None``
            means ``a`` is flattened first.

    Returns:
        cupy.ndarray: The indexed result.

    Raises:
        IndexError: If ``indices`` is not an integer array, or ``axis`` is
            out of range for ``a``.
        ValueError: If ``indices`` and ``a`` differ in number of dimensions.

    .. seealso:: :func:`numpy.take_along_axis`
    """
    try:
        import cupy
    except ImportError:
        # CPU-only environment: fall back to numpy for the helper ranges.
        import numpy as xp
    else:
        xp = cupy.get_array_module(a)

    if indices.dtype.kind not in ("i", "u"):
        raise IndexError("`indices` must be an integer array")

    if axis is None:
        a = a.ravel()
        axis = 0

    ndim = a.ndim

    if not (-ndim <= axis < ndim):
        raise IndexError("Axis overrun")

    # Normalise negative axes so they can be compared with enumerate() below.
    axis %= ndim

    if ndim != indices.ndim:
        raise ValueError("`indices` and `a` must have the same number of dimensions")

    # Build one index array per dimension: the caller's indices on `axis`,
    # and a broadcastable arange on every other axis so that each position
    # selects "itself" there.
    fancy_index = []
    for i, n in enumerate(a.shape):
        if i == axis:
            fancy_index.append(indices)
        else:
            ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
            fancy_index.append(xp.arange(n).reshape(ind_shape))

    # Multidimensional indexing must use a tuple; indexing with a list is
    # deprecated and may be interpreted as a single index array.
    return a[tuple(fancy_index)]
141
+
142
+
143
def put_along_axis(a, indices, value, axis):
    """Put values into the input array by matching 1d index and data slices.

    Backport of ``put_along_axis`` for array libraries that do not ship it
    yet. ``a`` is modified in place. The index helper arrays are created
    with the array module that owns ``a`` (cupy for device arrays, numpy
    otherwise), so the function also works on plain numpy arrays and when
    cupy is not installed.

    Args:
        a (cupy.ndarray): Destination array, updated in place.
        indices (cupy.ndarray): Integer indices to write along each 1d
            slice of ``a``. Must have the same number of dimensions as
            ``a`` (one dimension when ``axis`` is ``None``).
        value: Scalar or array broadcastable against ``indices``; the
            values to store.
        axis (int or None): The axis to take 1d slices along. ``None``
            means ``a`` is treated as flattened.

    Returns:
        None

    Raises:
        IndexError: If ``indices`` is not an integer array, or ``axis`` is
            out of range for ``a``.
        ValueError: If ``indices`` and ``a`` differ in number of dimensions.

    .. seealso:: :func:`numpy.put_along_axis`
    """
    try:
        import cupy
    except ImportError:
        # CPU-only environment: fall back to numpy for the helper ranges.
        import numpy as xp
    else:
        xp = cupy.get_array_module(a)

    if indices.dtype.kind not in ("i", "u"):
        raise IndexError("`indices` must be an integer array")

    target = a
    if axis is None:
        # ravel() returns a view when it can and a copy otherwise; the copy
        # case is handled by writing the result back below.
        target = a.ravel()
        axis = 0

    ndim = target.ndim

    if not (-ndim <= axis < ndim):
        raise IndexError("Axis overrun")

    # Normalise negative axes so they can be compared with enumerate() below.
    axis %= ndim

    if ndim != indices.ndim:
        raise ValueError("`indices` and `a` must have the same number of dimensions")

    # Build one index array per dimension: the caller's indices on `axis`,
    # and a broadcastable arange on every other axis.
    fancy_index = []
    for i, n in enumerate(target.shape):
        if i == axis:
            fancy_index.append(indices)
        else:
            ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
            fancy_index.append(xp.arange(n).reshape(ind_shape))

    # Multidimensional indexing must use a tuple; indexing with a list is
    # deprecated and may be interpreted as a single index array.
    target[tuple(fancy_index)] = value

    if target is not a:
        # If ravel() had to copy (non-contiguous input), the assignment above
        # did not reach `a`; copy the flattened result back. When `target` is
        # a view this rewrites identical values and is harmless.
        a[...] = target.reshape(a.shape)
172
+
173
+
106
174
if __name__ == "__main__" :
107
175
try :
108
176
plac .call (main )
0 commit comments