Skip to content

Commit 28d9990

Browse files
authored
Add support for refinement with uint32_t index type (#563)
Closes #537. Needed change for the transition from Raft to cuVS. Authors: - Micka (https://github.com/lowener) Approvers: - Ben Frederickson (https://github.com/benfred) URL: #563
1 parent 1e548f8 commit 28d9990

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

cpp/include/cuvs/neighbors/refine.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,51 @@ void refine(raft::resources const& handle,
7676
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
7777
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
7878

79+
/**
80+
* @brief Refine nearest neighbor search.
81+
*
82+
* Refinement is an operation that follows an approximate NN search. The approximate search has
83+
* already selected n_candidates neighbor candidates for each query. We narrow it down to k
84+
* neighbors. For each query, we calculate the exact distance between the query and its
85+
* n_candidates neighbor candidates, and select the k nearest ones.
86+
*
87+
* The k nearest neighbors and distances are returned.
88+
*
89+
* Example usage
90+
* @code{.cpp}
91+
* using namespace cuvs::neighbors;
92+
* // use default index parameters
93+
* ivf_pq::index_params index_params;
94+
* // create and fill the index from a [N, D] dataset
95+
* auto index = ivf_pq::build(handle, index_params, dataset);
96+
* // use default search parameters
97+
* ivf_pq::search_params search_params;
98+
* // search m = 4 * k nearest neighbours for each of the N queries
99+
* ivf_pq::search(handle, search_params, index, queries, neighbor_candidates,
100+
* out_dists_tmp);
101+
* // refine them to the k nearest ones
102+
* refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists,
103+
* index.metric());
104+
* @endcode
105+
*
106+
*
107+
* @param[in] handle the raft handle
108+
* @param[in] dataset device matrix that stores the dataset [n_rows, dims]
109+
* @param[in] queries device matrix of the queries [n_queries, dims]
110+
* @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where
111+
* n_candidates >= k
112+
* @param[out] indices device matrix that stores the refined indices [n_queries, k]
113+
* @param[out] distances device matrix that stores the refined distances [n_queries, k]
114+
* @param[in] metric distance metric to use. Euclidean (L2) is used by default
115+
*/
116+
void refine(raft::resources const& handle,
117+
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
118+
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
119+
raft::device_matrix_view<const uint32_t, int64_t, raft::row_major> neighbor_candidates,
120+
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> indices,
121+
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
122+
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
123+
79124
/**
80125
* @brief Refine nearest neighbor search.
81126
*

cpp/src/neighbors/ivf_flat_index.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ void index<T, IdxT>::check_consistency()
226226
"inconsistent number of lists (clusters)");
227227
}
228228

229+
template struct index<float, uint32_t>; // Used for refine function
229230
template struct index<float, int64_t>;
230231
template struct index<half, int64_t>;
231232
template struct index<int8_t, int64_t>;

cpp/src/neighbors/refine/detail/refine_device_float_float.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@
4343
}
4444

4545
instantiate_cuvs_neighbors_refine_d(int64_t, float, float, int64_t);
46+
instantiate_cuvs_neighbors_refine_d(uint32_t, float, float, int64_t);
4647

4748
#undef instantiate_cuvs_neighbors_refine_d

cpp/src/neighbors/refine/refine_device.cuh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,13 @@ void refine_device(
8484
cuvs::neighbors::ivf_flat::index<data_t, idx_t> refinement_index(
8585
handle, cuvs::distance::DistanceType(metric), n_queries, false, true, dim);
8686

87-
cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle,
88-
&refinement_index,
89-
dataset.data_handle(),
90-
neighbor_candidates.data_handle(),
91-
n_queries,
92-
n_candidates);
87+
cuvs::neighbors::ivf_flat::detail::fill_refinement_index<data_t, idx_t>(
88+
handle,
89+
&refinement_index,
90+
dataset.data_handle(),
91+
neighbor_candidates.data_handle(),
92+
static_cast<idx_t>(n_queries),
93+
static_cast<uint32_t>(n_candidates));
9394
uint32_t grid_dim_x = 1;
9495

9596
// the neighbor ids will be computed in uint32_t as offset

0 commit comments

Comments
 (0)