@@ -76,6 +76,51 @@ void refine(raft::resources const& handle,
7676 raft::device_matrix_view<float , int64_t , raft::row_major> distances,
7777 cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
7878
79+ /* *
80+ * @brief Refine nearest neighbor search.
81+ *
82+ * Refinement is an operation that follows an approximate NN search. The approximate search has
83+ * already selected n_candidates neighbor candidates for each query. We narrow it down to k
84+ * neighbors. For each query, we calculate the exact distance between the query and its
85+ * n_candidates neighbor candidate, and select the k nearest ones.
86+ *
87+ * The k nearest neighbors and distances are returned.
88+ *
89+ * Example usage
90+ * @code{.cpp}
91+ * using namespace cuvs::neighbors;
92+ * // use default index parameters
93+ * ivf_pq::index_params index_params;
94+ * // create and fill the index from a [N, D] dataset
95+ * auto index = ivf_pq::build(handle, index_params, dataset);
96+ * // use default search parameters
97+ * ivf_pq::search_params search_params;
98+ * // search m = 4 * k nearest neighbours for each of the N queries
99+ * ivf_pq::search(handle, search_params, index, queries, neighbor_candidates,
100+ * out_dists_tmp);
101+ * // refine it to the k nearest one
102+ * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists,
103+ * index.metric());
104+ * @endcode
105+ *
106+ *
107+ * @param[in] handle the raft handle
108+ * @param[in] dataset device matrix that stores the dataset [n_rows, dims]
109+ * @param[in] queries device matrix of the queries [n_queris, dims]
110+ * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where
111+ * n_candidates >= k
112+ * @param[out] indices device matrix that stores the refined indices [n_queries, k]
113+ * @param[out] distances device matrix that stores the refined distances [n_queries, k]
114+ * @param[in] metric distance metric to use. Euclidean (L2) is used by default
115+ */
116+ void refine (raft::resources const & handle,
117+ raft::device_matrix_view<const float , int64_t , raft::row_major> dataset,
118+ raft::device_matrix_view<const float , int64_t , raft::row_major> queries,
119+ raft::device_matrix_view<const uint32_t , int64_t , raft::row_major> neighbor_candidates,
120+ raft::device_matrix_view<uint32_t , int64_t , raft::row_major> indices,
121+ raft::device_matrix_view<float , int64_t , raft::row_major> distances,
122+ cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);
123+
79124/* *
80125 * @brief Refine nearest neighbor search.
81126 *
0 commit comments