Skip to content

Commit c80fc1d

Browse files
authored
Expose ivf-pq centers to python/c (#881)
Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #881
1 parent 19a1759 commit c80fc1d

File tree

5 files changed

+107
-4
lines changed

5 files changed

+107
-4
lines changed

cpp/include/cuvs/neighbors/ivf_pq.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,24 @@ cuvsError_t cuvsIvfPqIndexCreate(cuvsIvfPqIndex_t* index);
260260
* @param[in] index cuvsIvfPqIndex_t to de-allocate
261261
*/
262262
cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index);
263+
264+
/** Get the number of clusters/inverted lists */
265+
uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index);
266+
267+
/** Get the dimensionality of the cluster centers */
268+
uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index);
269+
270+
/**
271+
* @brief Get the cluster centers corresponding to the lists in the original space
272+
*
273+
* @param[in] res cuvsResources_t opaque C handle
274+
* @param[in] index cuvsIvfPqIndex_t Built IVF-PQ index
275+
* @param[out] centers Preallocated array on host or device memory to store output, [n_lists, dim_ext]
276+
* @return cuvsError_t
277+
*/
278+
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
279+
cuvsIvfPqIndex_t index,
280+
DLManagedTensor* centers);
263281
/**
264282
* @}
265283
*/

cpp/src/neighbors/ivf_pq_c.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,24 @@ void _extend(cuvsResources_t res,
143143
cuvs::neighbors::ivf_pq::extend(*res_ptr, vectors_mds, indices_mds, index_ptr);
144144
}
145145
}
146+
147+
template <typename output_mdspan_type, typename IdxT>
148+
void _get_centers(cuvsResources_t res, cuvsIvfPqIndex index, DLManagedTensor* centers)
149+
{
150+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
151+
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
152+
auto dst = cuvs::core::from_dlpack<output_mdspan_type>(centers);
153+
auto src = index_ptr->centers();
154+
155+
RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output centers has incorrect number of rows");
156+
RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output centers has incorrect number of cols");
157+
158+
cudaMemcpyAsync(dst.data_handle(),
159+
src.data_handle(),
160+
dst.extent(0) * dst.extent(1) * sizeof(float),
161+
cudaMemcpyDefault,
162+
raft::resource::get_cuda_stream(*res_ptr));
163+
}
146164
} // namespace
147165

148166
extern "C" cuvsError_t cuvsIvfPqIndexCreate(cuvsIvfPqIndex_t* index)
@@ -312,3 +330,30 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
312330
}
313331
});
314332
}
333+
334+
/**
 * C entry point: return `n_lists()` of the index stored behind `index->addr`.
 *
 * The address is interpreted as a `cuvs::neighbors::ivf_pq::index<int64_t>*`.
 */
extern "C" uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)
{
  using index_type = cuvs::neighbors::ivf_pq::index<int64_t>;
  return reinterpret_cast<index_type*>(index->addr)->n_lists();
}
339+
340+
/**
 * C entry point: return `dim_ext()` of the index stored behind `index->addr`.
 *
 * The address is interpreted as a `cuvs::neighbors::ivf_pq::index<int64_t>*`.
 */
extern "C" uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)
{
  using index_type = cuvs::neighbors::ivf_pq::index<int64_t>;
  return reinterpret_cast<index_type*>(index->addr)->dim_ext();
}
345+
346+
/**
 * C entry point: copy the cluster centers of `index` into the DLPack tensor `centers`.
 *
 * The destination is viewed as a device matrix when the tensor is device compatible and
 * as a host matrix otherwise; both paths forward to `_get_centers` with `int64_t` indices.
 * Exceptions raised inside are converted to a `cuvsError_t` by `translate_exceptions`.
 */
extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
                                                cuvsIvfPqIndex_t index,
                                                DLManagedTensor* centers)
{
  using device_view_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
  using host_view_type   = raft::host_matrix_view<float, int64_t, raft::row_major>;

  return cuvs::core::translate_exceptions([=] {
    const bool output_on_device = cuvs::core::is_dlpack_device_compatible(centers->dl_tensor);
    if (output_on_device) {
      _get_centers<device_view_type, int64_t>(res, *index, centers);
    } else {
      _get_centers<host_view_type, int64_t>(res, *index, centers);
    }
  });
}

python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pxd

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,25 +81,33 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:
8181

8282
cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index)
8383

84+
uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)
85+
86+
uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)
87+
88+
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
89+
cuvsIvfPqIndex_t index,
90+
DLManagedTensor * centers)
91+
8492
cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
8593
cuvsIvfPqIndexParams* params,
8694
DLManagedTensor* dataset,
87-
cuvsIvfPqIndex_t index) except +
95+
cuvsIvfPqIndex_t index)
8896

8997
cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
9098
cuvsIvfPqSearchParams* params,
9199
cuvsIvfPqIndex_t index,
92100
DLManagedTensor* queries,
93101
DLManagedTensor* neighbors,
94-
DLManagedTensor* distances) except +
102+
DLManagedTensor* distances)
95103

96104
cuvsError_t cuvsIvfPqSerialize(cuvsResources_t res,
97105
const char * filename,
98-
cuvsIvfPqIndex_t index) except +
106+
cuvsIvfPqIndex_t index)
99107

100108
cuvsError_t cuvsIvfPqDeserialize(cuvsResources_t res,
101109
const char * filename,
102-
cuvsIvfPqIndex_t index) except +
110+
cuvsIvfPqIndex_t index)
103111

104112
cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
105113
DLManagedTensor* new_vectors,

python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,35 @@ cdef class Index:
238238
def __repr__(self):
239239
return "Index(type=IvfPq)"
240240

241+
@property
242+
def n_lists(self):
243+
""" The number of inverted lists (clusters) """
244+
return cuvsIvfPqIndexGetNLists(self.index)
245+
246+
@property
247+
def dim_ext(self):
248+
""" dimensionality of the cluster centers """
249+
return cuvsIvfPqIndexGetDimExt(self.index)
250+
251+
@property
252+
def centers(self):
253+
""" Get the cluster centers corresponding to the lists in the
254+
original space """
255+
return self._get_centers()
256+
257+
@auto_sync_resources
258+
def _get_centers(self, resources=None):
259+
if not self.trained:
260+
raise ValueError("Index needs to be built before getting centers")
261+
262+
cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()
263+
264+
output = np.empty((self.n_lists, self.dim_ext), dtype='float32')
265+
ai = wrap_array(output)
266+
cdef cydlpack.DLManagedTensor* output_dlpack = cydlpack.dlpack_c(ai)
267+
check_cuvs(cuvsIvfPqIndexGetCenters(res, self.index, output_dlpack))
268+
return output
269+
241270

242271
@auto_sync_resources
243272
def build(IndexParams index_params, dataset, resources=None):

python/cuvs/cuvs/tests/test_ivf_pq.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def run_ivf_pq_build_search_test(
110110
if not inplace:
111111
out_dist_device, out_idx_device = ret_output
112112

113+
centers = index.centers
114+
assert centers.shape[0] == n_lists
115+
113116
if not compare:
114117
return
115118

0 commit comments

Comments
 (0)