 namespace {
 using namespace cuvs::neighbors;
 
+template <typename T>
+void convert_c_index_params(cuvsTieredIndexParams params,
+                            int64_t n_rows,
+                            int64_t dim,
+                            tiered_index::index_params<T>* out)
+{
+  out->min_ann_rows               = params.min_ann_rows;
+  out->create_ann_index_on_extend = params.create_ann_index_on_extend;
+  out->metric                     = params.metric;
+
+  if constexpr (std::is_same_v<T, cagra::index_params>) {
+    if (params.cagra_params != NULL) {
+      cagra::convert_c_index_params(*params.cagra_params, n_rows, dim, out);
+    }
+  } else if constexpr (std::is_same_v<T, ivf_flat::index_params>) {
+    if (params.ivf_flat_params != NULL) {
+      ivf_flat::convert_c_index_params(*params.ivf_flat_params, out);
+    }
+  } else if constexpr (std::is_same_v<T, ivf_pq::index_params>) {
+    if (params.ivf_pq_params != NULL) {
+      ivf_pq::convert_c_index_params(*params.ivf_pq_params, out);
+    }
+  } else {
+    RAFT_FAIL("unhandled index params type");
+  }
+}
+
 template <typename T>
 void* _build(cuvsResources_t res, cuvsTieredIndexParams params, DLManagedTensor* dataset_tensor)
 {
   auto res_ptr      = reinterpret_cast<raft::resources*>(res);
   using mdspan_type = raft::device_matrix_view<const T, int64_t, raft::row_major>;
   auto mds          = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor);
 
+  auto dataset = dataset_tensor->dl_tensor;
+  RAFT_EXPECTS(dataset.ndim == 2, "dataset should be a 2-dimensional tensor");
+  RAFT_EXPECTS(dataset.shape != nullptr, "dataset should have an initialized shape");
+
   switch (params.algo) {
     case CUVS_TIERED_INDEX_ALGO_CAGRA: {
       auto build_params = tiered_index::index_params<cagra::index_params>();
-      if (params.cagra_params != NULL) {
-        auto dataset = dataset_tensor->dl_tensor;
-        cagra::convert_c_index_params(
-          *params.cagra_params, dataset.shape[0], dataset.shape[1], &build_params);
-      }
-      build_params.min_ann_rows               = params.min_ann_rows;
-      build_params.create_ann_index_on_extend = params.create_ann_index_on_extend;
-      build_params.metric                     = params.metric;
+      convert_c_index_params(params, dataset.shape[0], dataset.shape[1], &build_params);
       return new tiered_index::index<cagra::index<T, uint32_t>>(
         tiered_index::build(*res_ptr, build_params, mds));
     }
     case CUVS_TIERED_INDEX_ALGO_IVF_FLAT: {
       auto build_params = tiered_index::index_params<ivf_flat::index_params>();
-      if (params.ivf_flat_params != NULL) {
-        ivf_flat::convert_c_index_params(*params.ivf_flat_params, &build_params);
-      }
-      build_params.metric                     = params.metric;
-      build_params.min_ann_rows               = params.min_ann_rows;
-      build_params.create_ann_index_on_extend = params.create_ann_index_on_extend;
+      convert_c_index_params(params, dataset.shape[0], dataset.shape[1], &build_params);
       return new tiered_index::index<ivf_flat::index<T, int64_t>>(
         tiered_index::build(*res_ptr, build_params, mds));
     }
     case CUVS_TIERED_INDEX_ALGO_IVF_PQ: {
-      auto build_params = tiered_index::index_params<ivf_pq::index_params>();
-      build_params.metric = params.metric;
-      if (params.ivf_pq_params != NULL) {
-        ivf_pq::convert_c_index_params(*params.ivf_pq_params, &build_params);
-      }
-      build_params.metric                     = params.metric;
-      build_params.min_ann_rows               = params.min_ann_rows;
-      build_params.create_ann_index_on_extend = params.create_ann_index_on_extend;
+      auto build_params = tiered_index::index_params<ivf_pq::index_params>();
+      convert_c_index_params(params, dataset.shape[0], dataset.shape[1], &build_params);
       return new tiered_index::index<ivf_pq::typed_index<T, int64_t>>(
         tiered_index::build(*res_ptr, build_params, mds));
     }
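
The new helper above folds the three per-algorithm parameter conversions into one function by dispatching on the instantiated type with `if constexpr`: only the branch whose `std::is_same_v` test is true for `T` is compiled in, so each branch may reference members that exist only for its own params type. A standalone sketch of the same pattern, for illustration only (the `cagra_params`/`ivf_flat_params` structs here are stand-ins, not cuvs types):

#include <cstdio>
#include <type_traits>

struct cagra_params    { int graph_degree; };
struct ivf_flat_params { int n_lists; };

template <typename T>
void describe(const T& p)
{
  // Branches whose condition is false for this T are discarded at compile
  // time, so accessing p.graph_degree below is legal even though
  // ivf_flat_params has no such member.
  if constexpr (std::is_same_v<T, cagra_params>) {
    std::printf("cagra: graph_degree=%d\n", p.graph_degree);
  } else if constexpr (std::is_same_v<T, ivf_flat_params>) {
    std::printf("ivf_flat: n_lists=%d\n", p.n_lists);
  } else {
    std::printf("unhandled params type\n");  // runtime failure, as with RAFT_FAIL above
  }
}

int main()
{
  describe(cagra_params{64});       // -> cagra: graph_degree=64
  describe(ivf_flat_params{1024});  // -> ivf_flat: n_lists=1024
  return 0;
}
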
@@ -157,6 +170,47 @@ void _extend(cuvsResources_t res, DLManagedTensor* new_vectors, cuvsTieredIndex |
 
   tiered_index::extend(*res_ptr, vectors_mds, index_ptr);
 }
+template <typename UpstreamT>
+void _merge(cuvsResources_t res,
+            cuvsTieredIndexParams params,
+            cuvsTieredIndex_t* indices,
+            size_t num_indices,
+            cuvsTieredIndex_t output_index)
+{
+  auto res_ptr = reinterpret_cast<raft::resources*>(res);
+
+  std::vector<cuvs::neighbors::tiered_index::index<UpstreamT>*> cpp_indices;
+
+  int64_t n_rows = 0, dim = 0;
+  for (size_t i = 0; i < num_indices; ++i) {
+    RAFT_EXPECTS(indices[i]->dtype.code == indices[0]->dtype.code,
+                 "indices must all have the same dtype");
+    RAFT_EXPECTS(indices[i]->dtype.bits == indices[0]->dtype.bits,
+                 "indices must all have the same dtype");
+    RAFT_EXPECTS(indices[i]->algo == indices[0]->algo,
+                 "indices must all have the same index algorithm");
+
+    auto idx_ptr =
+      reinterpret_cast<cuvs::neighbors::tiered_index::index<UpstreamT>*>(indices[i]->addr);
+    n_rows += idx_ptr->size();
+    if (dim) {
+      RAFT_EXPECTS(dim == idx_ptr->dim(), "indices must all have the same dimensionality");
+    } else {
+      dim = idx_ptr->dim();
+    }
+    cpp_indices.push_back(idx_ptr);
+  }
+
+  auto build_params = tiered_index::index_params<typename UpstreamT::index_params_type>();
+  convert_c_index_params(params, n_rows, dim, &build_params);
+
+  auto ptr = new cuvs::neighbors::tiered_index::index<UpstreamT>(
+    cuvs::neighbors::tiered_index::merge(*res_ptr, build_params, cpp_indices));
+
+  output_index->addr  = reinterpret_cast<uintptr_t>(ptr);
+  output_index->dtype = indices[0]->dtype;
+  output_index->algo  = indices[0]->algo;
+}
 
 }  // namespace
 
@@ -305,3 +359,31 @@ extern "C" cuvsError_t cuvsTieredIndexExtend(cuvsResources_t res, |
     }
   });
 }
+
+extern "C" cuvsError_t cuvsTieredIndexMerge(cuvsResources_t res,
+                                            cuvsTieredIndexParams_t params,
+                                            cuvsTieredIndex_t* indices,
+                                            size_t num_indices,
+                                            cuvsTieredIndex_t output_index)
+{
+  return cuvs::core::translate_exceptions([=] {
+    RAFT_EXPECTS(num_indices >= 1, "must have at least one index to merge");
+
+    switch (indices[0]->algo) {
+      case CUVS_TIERED_INDEX_ALGO_CAGRA: {
+        _merge<cagra::index<float, uint32_t>>(res, *params, indices, num_indices, output_index);
+        break;
+      }
+      case CUVS_TIERED_INDEX_ALGO_IVF_FLAT: {
+        _merge<ivf_flat::index<float, int64_t>>(res, *params, indices, num_indices, output_index);
+        break;
+      }
+      case CUVS_TIERED_INDEX_ALGO_IVF_PQ: {
+        _merge<ivf_pq::typed_index<float, int64_t>>(
+          res, *params, indices, num_indices, output_index);
+        break;
+      }
+      default: RAFT_FAIL("unsupported tiered index algorithm");
+    }
+  });
+}
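
For context, a minimal sketch of how the new cuvsTieredIndexMerge entry point might be driven from the C API. Only cuvsTieredIndexMerge itself is added in this diff; the cuvsTieredIndexParamsCreate, cuvsTieredIndexCreate, cuvsTieredIndexBuild, and matching Destroy calls are assumed from the existing tiered_index C header, and DLPack tensor construction plus error checking are elided:

#include <cuvs/neighbors/tiered_index.h>

/* Sketch: merge two previously built tiered indices into one.
 * Assumes dataset_a and dataset_b are DLManagedTensor* views of row-major
 * device matrices with identical dtype and dimensionality; building those
 * tensors and checking cuvsError_t returns are omitted for brevity. */
void merge_two_indices(cuvsResources_t res,
                       DLManagedTensor* dataset_a,
                       DLManagedTensor* dataset_b)
{
  cuvsTieredIndexParams_t params;
  cuvsTieredIndexParamsCreate(&params);
  params->algo = CUVS_TIERED_INDEX_ALGO_CAGRA;

  cuvsTieredIndex_t index_a, index_b, merged;
  cuvsTieredIndexCreate(&index_a);
  cuvsTieredIndexCreate(&index_b);
  cuvsTieredIndexCreate(&merged);

  cuvsTieredIndexBuild(res, params, dataset_a, index_a);
  cuvsTieredIndexBuild(res, params, dataset_b, index_b);

  /* The new entry point: inputs must share dtype and algo, and the merged
   * index inherits both (see _merge above). */
  cuvsTieredIndex_t inputs[] = {index_a, index_b};
  cuvsTieredIndexMerge(res, params, inputs, 2, merged);

  cuvsTieredIndexDestroy(index_a);
  cuvsTieredIndexDestroy(index_b);
  cuvsTieredIndexDestroy(merged);
  cuvsTieredIndexParamsDestroy(params);
}
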