Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ cuvsError_t cuvsIvfPqIndexCreate(cuvsIvfPqIndex_t* index);
* @param[in] index cuvsIvfPqIndex_t to de-allocate
*/
cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index);

/** Get the number of clusters/inverted lists */
uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index);

/** Get the dimensionality of the cluster centers */
uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index);

/**
* @brief Get the cluster centers corresponding to the lists in the original space
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] index cuvsIvfPqIndex_t Built NN-Descent index
* @param[out] centers Preallocated array on host memory to store output, [n_lists, dim_ext]
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
cuvsIvfPqIndex_t index,
DLManagedTensor* centers);
/**
* @}
*/
Expand Down
45 changes: 45 additions & 0 deletions cpp/src/neighbors/ivf_pq_c.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ void _extend(cuvsResources_t res,

cuvs::neighbors::ivf_pq::extend(*res_ptr, vectors_mds, indices_mds, index_ptr);
}

template <typename output_mdspan_type, typename IdxT>
void _get_centers(cuvsResources_t res, cuvsIvfPqIndex index, DLManagedTensor* centers)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
auto dst = cuvs::core::from_dlpack<output_mdspan_type>(centers);
auto src = index_ptr->centers();

RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output centers has incorrect number of rows");
RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output centers has incorrect number of cols");

cudaMemcpyAsync(dst.data_handle(),
src.data_handle(),
dst.extent(0) * dst.extent(1) * sizeof(float),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(*res_ptr));
}
} // namespace

extern "C" cuvsError_t cuvsIvfPqIndexCreate(cuvsIvfPqIndex_t* index)
Expand Down Expand Up @@ -295,3 +313,30 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
}
});
}

extern "C" uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)
{
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<int64_t>*>(index->addr);
return index_ptr->n_lists();
}

extern "C" uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)
{
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<int64_t>*>(index->addr);
return index_ptr->dim_ext();
}

extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
cuvsIvfPqIndex_t index,
DLManagedTensor* centers)
{
return cuvs::core::translate_exceptions([=] {
if (cuvs::core::is_dlpack_device_compatible(centers->dl_tensor)) {
using output_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
_get_centers<output_mdspan_type, int64_t>(res, *index, centers);
} else {
using output_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
_get_centers<output_mdspan_type, int64_t>(res, *index, centers);
}
});
}
16 changes: 12 additions & 4 deletions python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,33 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:

cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index)

uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)

uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
cuvsIvfPqIndex_t index,
DLManagedTensor * centers)

cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
cuvsIvfPqIndexParams* params,
DLManagedTensor* dataset,
cuvsIvfPqIndex_t index) except +
cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
cuvsIvfPqSearchParams* params,
cuvsIvfPqIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances) except +
DLManagedTensor* distances)

cuvsError_t cuvsIvfPqSerialize(cuvsResources_t res,
const char * filename,
cuvsIvfPqIndex_t index) except +
cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqDeserialize(cuvsResources_t res,
const char * filename,
cuvsIvfPqIndex_t index) except +
cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
DLManagedTensor* new_vectors,
Expand Down
29 changes: 29 additions & 0 deletions python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,35 @@ cdef class Index:
def __repr__(self):
return "Index(type=IvfPq)"

@property
def n_lists(self):
""" The number of inverted lists (clusters) """
return cuvsIvfPqIndexGetNLists(self.index)

@property
def dim_ext(self):
""" dimensionality of the cluster centers """
return cuvsIvfPqIndexGetDimExt(self.index)

@property
def centers(self):
""" Get the cluster centers corresponding to the lists in the
original space """
return self._get_centers()

@auto_sync_resources
def _get_centers(self, resources=None):
if not self.trained:
raise ValueError("Index needs to be built before getting centers")

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

output = np.empty((self.n_lists, self.dim_ext), dtype='float32')
ai = wrap_array(output)
cdef cydlpack.DLManagedTensor* output_dlpack = cydlpack.dlpack_c(ai)
check_cuvs(cuvsIvfPqIndexGetCenters(res, self.index, output_dlpack))
return output


@auto_sync_resources
def build(IndexParams index_params, dataset, resources=None):
Expand Down
3 changes: 3 additions & 0 deletions python/cuvs/cuvs/tests/test_ivf_pq.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def run_ivf_pq_build_search_test(
if not inplace:
out_dist_device, out_idx_device = ret_output

centers = index.centers
assert centers.shape[0] == n_lists

if not compare:
return

Expand Down