
Commit 0924979

jinsolp authored and lowener committed
Make duplicate removal in all neighbors robust to distance drift across batches (rapidsai#1185)
This fixes an edge case whose root cause is described in issue rapidsai#1056: `raft::linalg::gemm` can return slightly different distance results depending on the input sizes.

Currently, when merging two knn graphs from different batches, we sort by distances (i.e. keys) and, when distances are equal, by indices (i.e. values). We then compare indices that sit right next to each other to find duplicates, under the assumption that the same vector always ends up with the same distance. However, because of the problem in issue 1056, the distance for the same index can differ slightly depending on the size of the input matrix passed to gemm (or where the vector sits in the overall matrix). For example, when computing the nearest neighbors of vector 0 we could end up with

```
indices   = [1, 2, 3, 2, ....]
distances = [0.023, 0.02355981, 0.02355983, 0.02355987]
```

because the distance between vector 0 and vector 2 is computed as 0.02355981 in the first batch and 0.02355987 in the second batch. This PR fixes the issue by checking the 4 entries to the left of each candidate for duplicates, instead of only the entry immediately next to it.

Authors:
  - Jinsol Park (https://github.com/jinsolp)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1185
1 parent f3e166e commit 0924979
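The gap between the two distances for index 2 above is tiny, but it is enough to let another candidate slide in between once the merged row is sorted by distance. The following standalone C++ sketch (not part of the cuVS sources; the index and distance values are taken from the example in the commit message) shows why an adjacent-only comparison misses the duplicate while a 4-wide look-back window catches it:

```cpp
// Standalone illustration: adjacent-only dedup vs. a 4-wide look-back window
// on a merged candidate list whose distances drifted between batches.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main()
{
  // Candidate neighbors of vector 0 after merging two batches.
  // Index 2 appears twice with slightly different distances (gemm drift),
  // so its two occurrences are not adjacent after sorting.
  std::vector<std::pair<float, int64_t>> cand = {
    {0.023f, 1}, {0.02355981f, 2}, {0.02355983f, 3}, {0.02355987f, 2}};
  std::sort(cand.begin(), cand.end());  // sort by distance, then by index

  constexpr int window = 4;  // same look-back width the kernel uses
  for (std::size_t i = 0; i < cand.size(); ++i) {
    bool adjacent_dup = i > 0 && cand[i].second == cand[i - 1].second;
    bool window_dup   = false;
    for (int off = 1; off <= window && off <= static_cast<int>(i); ++off) {
      if (cand[i - off].second == cand[i].second) { window_dup = true; break; }
    }
    std::cout << "idx " << cand[i].second << " dist " << cand[i].first
              << " adjacent-dup=" << adjacent_dup
              << " window-dup=" << window_dup << "\n";
  }
  return 0;
}
```

Sorting keeps the two occurrences of index 2 apart because index 3's distance falls between them; only the wider look-back sees the earlier occurrence and flags the second one as a duplicate.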

File tree

1 file changed: +56 -21 lines changed


cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh

Lines changed: 56 additions & 21 deletions

@@ -30,7 +30,7 @@
 namespace cuvs::neighbors::all_neighbors::detail {
 using namespace cuvs::neighbors;
 
-template <typename IdxT, int BLOCK_SIZE, int ITEMS_PER_THREAD>
+template <typename IdxT, int BLOCK_SIZE, int ITEMS_PER_THREAD, bool SweepAll = false>
 RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
                                    size_t graph_degree,
                                    size_t num_cluster_in_batch,
@@ -102,11 +102,12 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
 
   __syncthreads();
 
+  size_t limit = 2 * graph_degree;
   // load sorted result into shared memory to get unique values
   idxBase = threadIdx.x * ITEMS_PER_THREAD;
   for (int i = 0; i < ITEMS_PER_THREAD; i++) {
     size_t colId = idxBase + i;
-    if (colId < 2 * graph_degree) {
+    if (colId < limit) {
       blockKeys[colId]   = threadKeyValuePair[i].key;
       blockValues[colId] = threadKeyValuePair[i].value;
     }
@@ -118,16 +119,47 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
   if (threadIdx.x == 0) { uniqueMask[0] = 1; }
   for (int i = 0; i < ITEMS_PER_THREAD; i++) {
     size_t colId = idxBase + i;
-    if (colId > 0 && colId < 2 * graph_degree) {
-      uniqueMask[colId] = static_cast<int16_t>(blockValues[colId] != blockValues[colId - 1]);
+    if (colId > 0 && colId < limit) {
+      // this assumes same distance between vector from two different batches. however, currently
+      // there are subtle differences in the result based on the matrix size used to call gemm.
+      // This makes it difficult to remove duplicates, because they might no longer be right next
+      // to each other after sorting by distances. Thus, for now we sweep a neighboring window of
+      // size 4 or sweep the entire row to check for duplicates, and keep the first occurrence
+      // only.
+      // related issue: https://github.com/rapidsai/cuvs/issues/1056
+      // uniqueMask[colId] = static_cast<int16_t>(blockValues[colId] != blockValues[colId - 1]);
+
+      int is_unique = 1;
+
+      if constexpr (SweepAll) {  // sweep whole row for better deduplication
+        for (int j = 0; j < limit; ++j) {
+          if (j < colId && blockValues[j] == blockValues[colId]) {
+            is_unique = 0;
+            break;
+          }
+        }
+      } else {  // otherwise sweep a small window
+        IdxT curr_val = blockValues[colId];
+#pragma unroll
+        for (int offset = -4; offset < 0; offset++) {
+          int neighbor_idx = static_cast<int>(colId) + offset;
+          if (neighbor_idx >= 0) {
+            if (blockValues[neighbor_idx] == curr_val) {
+              is_unique = 0;
+              break;
+            }
+          }
+        }
+      }
+      uniqueMask[colId] = static_cast<int16_t>(is_unique);
     }
   }
 
   __syncthreads();
 
   // prefix sum
   if (threadIdx.x == 0) {
-    for (int i = 1; i < 2 * graph_degree; i++) {
+    for (int i = 1; i < limit; i++) {
       uniqueMask[i] += uniqueMask[i - 1];
     }
   }
@@ -141,7 +173,7 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
 
   for (int i = 0; i < ITEMS_PER_THREAD; i++) {
     size_t colId = idxBase + i;
-    if (colId > 0 && colId < 2 * graph_degree) {
+    if (colId > 0 && colId < limit) {
       bool is_unique       = uniqueMask[colId] != uniqueMask[colId - 1];
       int16_t global_colId = uniqueMask[colId] - 1;
       if (is_unique && static_cast<size_t>(global_colId) < graph_degree) {
@@ -153,7 +185,7 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
   }
 }
 
-template <typename T, typename IdxT = int64_t>
+template <typename T, typename IdxT = int64_t, bool SweepAll = false>
 void merge_subgraphs(raft::resources const& res,
                      size_t k,
                      size_t num_data_in_cluster,
@@ -170,7 +202,7 @@ void merge_subgraphs(raft::resources const& res,
 #pragma omp critical  // for omp-using multi-gpu purposes
   {
     if (num_elems <= 128) {
-      merge_subgraphs_kernel<IdxT, 32, 4>
+      merge_subgraphs_kernel<IdxT, 32, 4, SweepAll>
        <<<num_data_in_cluster, 32, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
          inverted_indices_d,
          k,
@@ -181,7 +213,7 @@ void merge_subgraphs(raft::resources const& res,
          batch_neighbors_d,
          select_min);
    } else if (num_elems <= 512) {
-      merge_subgraphs_kernel<IdxT, 128, 4>
+      merge_subgraphs_kernel<IdxT, 128, 4, SweepAll>
        <<<num_data_in_cluster, 128, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
          inverted_indices_d,
          k,
@@ -192,7 +224,7 @@ void merge_subgraphs(raft::resources const& res,
          batch_neighbors_d,
          select_min);
    } else if (num_elems <= 1024) {
-      merge_subgraphs_kernel<IdxT, 128, 8>
+      merge_subgraphs_kernel<IdxT, 128, 8, SweepAll>
        <<<num_data_in_cluster, 128, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
          inverted_indices_d,
          k,
@@ -203,7 +235,7 @@ void merge_subgraphs(raft::resources const& res,
          batch_neighbors_d,
          select_min);
    } else if (num_elems <= 2048) {
-      merge_subgraphs_kernel<IdxT, 256, 8>
+      merge_subgraphs_kernel<IdxT, 256, 8, SweepAll>
        <<<num_data_in_cluster, 256, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
          inverted_indices_d,
          k,
@@ -221,7 +253,10 @@ void merge_subgraphs(raft::resources const& res,
   }
 }
 
-template <typename T, typename IdxT = int64_t, typename BeforeRemapT = int64_t>
+template <typename T,
+          typename IdxT = int64_t,
+          typename BeforeRemapT = int64_t,
+          bool SweepAll = false>
 void remap_and_merge_subgraphs(raft::resources const& res,
                                raft::device_vector_view<IdxT, IdxT> inverted_indices_d,
                                raft::host_vector_view<IdxT, IdxT> inverted_indices,
@@ -254,15 +289,15 @@ void remap_and_merge_subgraphs(raft::resources const& res,
              num_data_in_cluster * k,
              raft::resource::get_cuda_stream(res));
 
-  merge_subgraphs(res,
-                  k,
-                  num_data_in_cluster,
-                  inverted_indices_d.data_handle(),
-                  global_distances.data_handle(),
-                  batch_distances_d.data_handle(),
-                  global_neighbors.data_handle(),
-                  batch_neighbors_d.data_handle(),
-                  select_min);
+  merge_subgraphs<T, IdxT, SweepAll>(res,
+                                     k,
+                                     num_data_in_cluster,
+                                     inverted_indices_d.data_handle(),
+                                     global_distances.data_handle(),
+                                     batch_distances_d.data_handle(),
+                                     global_neighbors.data_handle(),
+                                     batch_neighbors_d.data_handle(),
+                                     select_min);
 }
 
 }  // namespace cuvs::neighbors::all_neighbors::detail
