Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
namespace cuvs::neighbors::all_neighbors::detail {
using namespace cuvs::neighbors;

template <typename IdxT, int BLOCK_SIZE, int ITEMS_PER_THREAD>
template <typename IdxT, int BLOCK_SIZE, int ITEMS_PER_THREAD, bool SweepAll = false>
RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
size_t graph_degree,
size_t num_cluster_in_batch,
Expand Down Expand Up @@ -102,11 +102,12 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,

__syncthreads();

size_t limit = 2 * graph_degree;
// load sorted result into shared memory to get unique values
idxBase = threadIdx.x * ITEMS_PER_THREAD;
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
size_t colId = idxBase + i;
if (colId < 2 * graph_degree) {
if (colId < limit) {
blockKeys[colId] = threadKeyValuePair[i].key;
blockValues[colId] = threadKeyValuePair[i].value;
}
Expand All @@ -118,16 +119,47 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
if (threadIdx.x == 0) { uniqueMask[0] = 1; }
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
size_t colId = idxBase + i;
if (colId > 0 && colId < 2 * graph_degree) {
uniqueMask[colId] = static_cast<int16_t>(blockValues[colId] != blockValues[colId - 1]);
if (colId > 0 && colId < limit) {
// this assumes same distance between vector from two different batches. however, currently
// there are subtle differences in the result based on the matrix size used to call gemm.
// This makes it difficult to remove duplicates, because they might no longer be right next
// to each other after sorting by distances. Thus, for now we sweep a neighboring window of
// size 4 or sweep the entire row to check for duplicates, and keep the first occurrence
// only.
// related issue: https://github.com/rapidsai/cuvs/issues/1056
// uniqueMask[colId] = static_cast<int16_t>(blockValues[colId] != blockValues[colId - 1]);

int is_unique = 1;

if constexpr (SweepAll) { // sweep whole row for better deduplication
for (int j = 0; j < limit; ++j) {
if (j < colId && blockValues[j] == blockValues[colId]) {
is_unique = 0;
break;
}
}
} else { // otherwise sweep a small window
IdxT curr_val = blockValues[colId];
#pragma unroll
for (int offset = -4; offset < 0; offset++) {
int neighbor_idx = static_cast<int>(colId) + offset;
if (neighbor_idx >= 0) {
if (blockValues[neighbor_idx] == curr_val) {
is_unique = 0;
break;
}
}
}
}
uniqueMask[colId] = static_cast<int16_t>(is_unique);
}
}

__syncthreads();

// prefix sum
if (threadIdx.x == 0) {
for (int i = 1; i < 2 * graph_degree; i++) {
for (int i = 1; i < limit; i++) {
uniqueMask[i] += uniqueMask[i - 1];
}
}
Expand All @@ -141,7 +173,7 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,

for (int i = 0; i < ITEMS_PER_THREAD; i++) {
size_t colId = idxBase + i;
if (colId > 0 && colId < 2 * graph_degree) {
if (colId > 0 && colId < limit) {
bool is_unique = uniqueMask[colId] != uniqueMask[colId - 1];
int16_t global_colId = uniqueMask[colId] - 1;
if (is_unique && static_cast<size_t>(global_colId) < graph_degree) {
Expand All @@ -153,7 +185,7 @@ RAFT_KERNEL merge_subgraphs_kernel(IdxT* cluster_data_indices,
}
}

template <typename T, typename IdxT = int64_t>
template <typename T, typename IdxT = int64_t, bool SweepAll = false>
void merge_subgraphs(raft::resources const& res,
size_t k,
size_t num_data_in_cluster,
Expand All @@ -170,7 +202,7 @@ void merge_subgraphs(raft::resources const& res,
#pragma omp critical // for omp-using multi-gpu purposes
{
if (num_elems <= 128) {
merge_subgraphs_kernel<IdxT, 32, 4>
merge_subgraphs_kernel<IdxT, 32, 4, SweepAll>
<<<num_data_in_cluster, 32, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
inverted_indices_d,
k,
Expand All @@ -181,7 +213,7 @@ void merge_subgraphs(raft::resources const& res,
batch_neighbors_d,
select_min);
} else if (num_elems <= 512) {
merge_subgraphs_kernel<IdxT, 128, 4>
merge_subgraphs_kernel<IdxT, 128, 4, SweepAll>
<<<num_data_in_cluster, 128, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
inverted_indices_d,
k,
Expand All @@ -192,7 +224,7 @@ void merge_subgraphs(raft::resources const& res,
batch_neighbors_d,
select_min);
} else if (num_elems <= 1024) {
merge_subgraphs_kernel<IdxT, 128, 8>
merge_subgraphs_kernel<IdxT, 128, 8, SweepAll>
<<<num_data_in_cluster, 128, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
inverted_indices_d,
k,
Expand All @@ -203,7 +235,7 @@ void merge_subgraphs(raft::resources const& res,
batch_neighbors_d,
select_min);
} else if (num_elems <= 2048) {
merge_subgraphs_kernel<IdxT, 256, 8>
merge_subgraphs_kernel<IdxT, 256, 8, SweepAll>
<<<num_data_in_cluster, 256, sharedMemSize, raft::resource::get_cuda_stream(res)>>>(
inverted_indices_d,
k,
Expand All @@ -221,7 +253,10 @@ void merge_subgraphs(raft::resources const& res,
}
}

template <typename T, typename IdxT = int64_t, typename BeforeRemapT = int64_t>
template <typename T,
typename IdxT = int64_t,
typename BeforeRemapT = int64_t,
bool SweepAll = false>
void remap_and_merge_subgraphs(raft::resources const& res,
raft::device_vector_view<IdxT, IdxT> inverted_indices_d,
raft::host_vector_view<IdxT, IdxT> inverted_indices,
Expand Down Expand Up @@ -254,15 +289,15 @@ void remap_and_merge_subgraphs(raft::resources const& res,
num_data_in_cluster * k,
raft::resource::get_cuda_stream(res));

merge_subgraphs(res,
k,
num_data_in_cluster,
inverted_indices_d.data_handle(),
global_distances.data_handle(),
batch_distances_d.data_handle(),
global_neighbors.data_handle(),
batch_neighbors_d.data_handle(),
select_min);
merge_subgraphs<T, IdxT, SweepAll>(res,
k,
num_data_in_cluster,
inverted_indices_d.data_handle(),
global_distances.data_handle(),
batch_distances_d.data_handle(),
global_neighbors.data_handle(),
batch_neighbors_d.data_handle(),
select_min);
}

} // namespace cuvs::neighbors::all_neighbors::detail