Skip to content

Commit 3c7f117

Browse files
authored
Python API for CAGRA+HNSW (#246)
Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #246
1 parent 496427f commit 3c7f117

File tree

21 files changed

+695
-57
lines changed

21 files changed

+695
-57
lines changed

cpp/include/cuvs/neighbors/cagra.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,15 @@ cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index);
267267
*/
268268
cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index);
269269

270+
/**
271+
* @brief Get dimension of the CAGRA index
272+
*
273+
* @param[in] index CAGRA index
274+
* @param[out] dim set to the dimensionality (number of elements per vector) of the index
275+
* @return cuvsError_t
276+
*/
277+
cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim);
278+
270279
/**
271280
* @}
272281
*/
@@ -338,7 +347,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
338347
* with the same type of `queries`, such that `index.dtype.code ==
339348
* queries.dl_tensor.dtype.code` Types for input are:
340349
* 1. `queries`:
341-
*` a. kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
350+
* a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
342351
* b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
343352
* c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
344353
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`

cpp/include/cuvs/neighbors/hnsw.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index);
105105
* with the same type of `queries`, such that `index.dtype.code ==
106106
* queries.dl_tensor.dtype.code`
107107
* Supported types for input are:
108-
* 1. `queries`: `kDLDataType.code == kDLFloat` or `kDLDataType.code == kDLInt` and
109-
* `kDLDataType.bits = 32`
108+
* 1. `queries`:
109+
* a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
110+
* b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
111+
* c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
110112
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 64`
111113
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
112114
* NOTE: The HNSW index can only be searched by the hnswlib wrapper in cuVS,

cpp/include/cuvs/neighbors/hnsw.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ std::unique_ptr<index<int8_t>> from_cagra(
173173

174174
/**@}*/
175175

176+
// TODO: Filtered Search APIs: https://github.com/rapidsai/cuvs/issues/363
177+
176178
/**
177179
* @defgroup hnsw_cpp_index_search Search hnswlib index
178180
* @{
@@ -260,7 +262,7 @@ void search(raft::resources const& res,
260262
void search(raft::resources const& res,
261263
const search_params& params,
262264
const index<uint8_t>& idx,
263-
raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
265+
raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
264266
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
265267
raft::host_matrix_view<float, int64_t, raft::row_major> distances);
266268

@@ -303,7 +305,7 @@ void search(raft::resources const& res,
303305
void search(raft::resources const& res,
304306
const search_params& params,
305307
const index<int8_t>& idx,
306-
raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
308+
raft::host_matrix_view<const int8_t, int64_t, raft::row_major> queries,
307309
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
308310
raft::host_matrix_view<float, int64_t, raft::row_major> distances);
309311

cpp/src/neighbors/cagra_c.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
176176
});
177177
}
178178

179+
// Writes the dimensionality of the CAGRA index wrapped by `index` into `*dim`.
//
// `index->addr` is a type-erased pointer to a `cuvs::neighbors::cagra::index<T, uint32_t>`.
// The element type T is selected from `index->dtype.code` before casting, so the index is
// never read through a pointer of the wrong specialization.
// NOTE(review): assumes kDLInt / kDLUInt indexes hold int8_t / uint8_t data, as listed in the
// cagra.h type documentation — confirm against cuvsCagraBuild.
//
// Any exception raised inside the lambda (including an unsupported dtype) is handed to
// `cuvs::core::translate_exceptions`, whose result is returned as the `cuvsError_t`.
extern "C" cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim)
{
  return cuvs::core::translate_exceptions([=] {
    if (index->dtype.code == kDLFloat) {
      auto index_ptr =
        reinterpret_cast<cuvs::neighbors::cagra::index<float, uint32_t>*>(index->addr);
      *dim = index_ptr->dim();
    } else if (index->dtype.code == kDLInt) {
      auto index_ptr =
        reinterpret_cast<cuvs::neighbors::cagra::index<int8_t, uint32_t>*>(index->addr);
      *dim = index_ptr->dim();
    } else if (index->dtype.code == kDLUInt) {
      auto index_ptr =
        reinterpret_cast<cuvs::neighbors::cagra::index<uint8_t, uint32_t>*>(index->addr);
      *dim = index_ptr->dim();
    } else {
      // Unknown element type: refuse to guess a specialization.
      RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
    }
  });
}
186+
179187
extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
180188
cuvsCagraIndexParams_t params,
181189
DLManagedTensor* dataset_tensor,

cpp/src/neighbors/detail/cagra/cagra_serialize.cuh

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ void serialize_to_hnswlib(raft::resources const& res,
120120
os.write(reinterpret_cast<char*>(&curr_element_count), sizeof(std::size_t));
121121
// Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t,
122122
// labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) +
123-
// dim * 4 + sizeof(labeltype)
124-
auto size_data_per_element =
125-
static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 + index_.dim() * 4 + 8);
123+
// dim * sizeof(T) + sizeof(labeltype)
124+
auto size_data_per_element = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 +
125+
index_.dim() * sizeof(T) + 8);
126126
os.write(reinterpret_cast<char*>(&size_data_per_element), sizeof(std::size_t));
127127
// label_offset
128128
std::size_t label_offset = size_data_per_element - 8;
@@ -185,18 +185,9 @@ void serialize_to_hnswlib(raft::resources const& res,
185185
}
186186

187187
auto data_row = host_dataset.data_handle() + (index_.dim() * i);
188-
if constexpr (std::is_same_v<T, float>) {
189-
for (std::size_t j = 0; j < index_.dim(); ++j) {
190-
auto data_elem = static_cast<float>(host_dataset(i, j));
191-
os.write(reinterpret_cast<char*>(&data_elem), sizeof(float));
192-
}
193-
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
194-
for (std::size_t j = 0; j < index_.dim(); ++j) {
195-
auto data_elem = static_cast<int>(host_dataset(i, j));
196-
os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
197-
}
198-
} else {
199-
RAFT_FAIL("Unsupported dataset type while saving CAGRA dataset to HNSWlib format");
188+
for (std::size_t j = 0; j < index_.dim(); ++j) {
189+
auto data_elem = static_cast<T>(host_dataset(i, j));
190+
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
200191
}
201192

202193
os.write(reinterpret_cast<char*>(&i), sizeof(std::size_t));

cpp/src/neighbors/detail/hnsw.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ std::unique_ptr<index<T>> from_cagra(raft::resources const& res,
110110
return std::unique_ptr<index<T>>(hnsw_index);
111111
}
112112

113-
template <typename QueriesT>
114-
void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
115-
const QueriesT* query,
113+
template <typename T>
114+
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
115+
const T* query,
116116
int k,
117117
uint64_t* indices,
118118
float* distances)
@@ -127,11 +127,11 @@ void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
127127
}
128128
}
129129

130-
template <typename T, typename QueriesT>
130+
template <typename T>
131131
void search(raft::resources const& res,
132132
const search_params& params,
133133
const index<T>& idx,
134-
raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries,
134+
raft::host_matrix_view<const T, int64_t, raft::row_major> queries,
135135
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
136136
raft::host_matrix_view<float, int64_t, raft::row_major> distances)
137137
{
@@ -146,7 +146,8 @@ void search(raft::resources const& res,
146146

147147
idx.set_ef(params.ef);
148148
auto const* hnswlib_index =
149-
reinterpret_cast<hnswlib::HierarchicalNSW<QueriesT> const*>(idx.get_index());
149+
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
150+
idx.get_index());
150151

151152
// when num_threads == 0, automatically maximize parallelism
152153
if (params.num_threads) {

cpp/src/neighbors/hnsw.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,20 @@ CUVS_INST_HNSW_FROM_CAGRA(int8_t);
3434

3535
#undef CUVS_INST_HNSW_FROM_CAGRA
3636

37-
#define CUVS_INST_HNSW_SEARCH(T, QueriesT) \
38-
void search(raft::resources const& res, \
39-
const search_params& params, \
40-
const index<T>& idx, \
41-
raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries, \
42-
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
43-
raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
44-
{ \
45-
detail::search<T, QueriesT>(res, params, idx, queries, neighbors, distances); \
37+
#define CUVS_INST_HNSW_SEARCH(T) \
38+
void search(raft::resources const& res, \
39+
const search_params& params, \
40+
const index<T>& idx, \
41+
raft::host_matrix_view<const T, int64_t, raft::row_major> queries, \
42+
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
43+
raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
44+
{ \
45+
detail::search<T>(res, params, idx, queries, neighbors, distances); \
4646
}
4747

48-
CUVS_INST_HNSW_SEARCH(float, float);
49-
CUVS_INST_HNSW_SEARCH(uint8_t, int);
50-
CUVS_INST_HNSW_SEARCH(int8_t, int);
48+
CUVS_INST_HNSW_SEARCH(float);
49+
CUVS_INST_HNSW_SEARCH(uint8_t);
50+
CUVS_INST_HNSW_SEARCH(int8_t);
5151

5252
#undef CUVS_INST_HNSW_SEARCH
5353

cpp/src/neighbors/hnsw_c.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#include <cuvs/neighbors/hnsw.hpp>
3232

3333
namespace {
34-
template <typename T, typename QueriesT>
34+
template <typename T>
3535
void _search(cuvsResources_t res,
3636
cuvsHnswSearchParams params,
3737
cuvsHnswIndex index,
@@ -46,7 +46,7 @@ void _search(cuvsResources_t res,
4646
search_params.ef = params.ef;
4747
search_params.num_threads = params.numThreads;
4848

49-
using queries_mdspan_type = raft::host_matrix_view<QueriesT const, int64_t, raft::row_major>;
49+
using queries_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
5050
using neighbors_mdspan_type = raft::host_matrix_view<uint64_t, int64_t, raft::row_major>;
5151
using distances_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
5252
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
@@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res,
127127

128128
auto index = *index_c_ptr;
129129
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");
130-
RAFT_EXPECTS(queries.dtype.bits == 32, "number of bits in queries dtype should be 32");
131130

132131
if (index.dtype.code == kDLFloat) {
133-
_search<float, float>(
134-
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
132+
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
135133
} else if (index.dtype.code == kDLUInt) {
136-
_search<uint8_t, int>(
137-
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
134+
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
138135
} else if (index.dtype.code == kDLInt) {
139-
_search<int8_t, int>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
136+
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
140137
} else {
141138
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", queries.dtype.code, queries.dtype.bits);
142139
}
@@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res,
152149
return cuvs::core::translate_exceptions([=] {
153150
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
154151
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename, dim, metric));
155-
index->dtype.code = kDLFloat;
156152
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
157153
index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t>(res, filename, dim, metric));
158-
index->dtype.code = kDLInt;
159154
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
160155
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename, dim, metric));
161-
index->dtype.code = kDLUInt;
162156
} else {
163157
RAFT_FAIL("Unsupported dtype in file %s", filename);
164158
}

docs/source/c_api/neighbors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ Nearest Neighbors
1313
neighbors_ivf_flat_c.rst
1414
neighbors_ivf_pq_c.rst
1515
neighbors_cagra_c.rst
16+
neighbors_hnsw_c.rst

docs/source/cpp_api/neighbors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Nearest Neighbors
1111

1212
neighbors_bruteforce.rst
1313
neighbors_cagra.rst
14+
neighbors_hnsw.rst
1415
neighbors_ivf_flat.rst
1516
neighbors_ivf_pq.rst
1617
neighbors_nn_descent.rst

0 commit comments

Comments
 (0)