Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>
Expand Down Expand Up @@ -385,13 +386,16 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
* @param[in] prefilter cuvsFilter input prefilter that can be used
to filter queries and neighbors based on the given bitset.
*/
cuvsError_t cuvsCagraSearch(cuvsResources_t res,
cuvsCagraSearchParams_t params,
cuvsCagraIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
DLManagedTensor* distances,
cuvsFilter filter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API change needs to be propagated to:

  • the python package
  • the example C project (cuvs/example/c)
  • probably the rust package

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done👍


/**
* @}
Expand Down
39 changes: 32 additions & 7 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/cagra.h>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.h>

#include <fstream>

Expand Down Expand Up @@ -91,7 +92,8 @@ void _search(cuvsResources_t res,
cuvsCagraIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr);
Expand All @@ -117,8 +119,27 @@ void _search(cuvsResources_t res,
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
if (filter.type == NO_FILTER) {
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<int64_t, int64_t, raft::row_major>;
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
*res_ptr, removed_indices, index_ptr->dataset().extent(0));
auto bitset_filter_obj =
cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cuvs::neighbors::cagra::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);
} else {
RAFT_FAIL("Unsupported prefilter type: BITMAP");
}
}

template <typename T>
Expand Down Expand Up @@ -213,7 +234,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
cuvsCagraIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
Expand All @@ -236,11 +258,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<float>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<int8_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<uint8_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
123 changes: 122 additions & 1 deletion cpp/test/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ float queries[4][2] = {{0.48216683, 0.0428398},
{0.51260436, 0.2643005},
{0.05198065, 0.5789965}};

int64_t filter[2] = {1, 2};

uint32_t neighbors_exp[4] = {3, 0, 3, 1};
float distances_exp[4] = {0.03878258, 0.12472608, 0.04776672, 0.15224178};

uint32_t neighbors_exp_filtered[4] = {3, 0, 3, 0};
float distances_exp_filtered[4] = {0.03878258, 0.12472608, 0.04776672, 0.59063464};

TEST(CagraC, BuildSearch)
{
// create cuvsResources_t
Expand Down Expand Up @@ -109,10 +114,15 @@ TEST(CagraC, BuildSearch)
distances_tensor.dl_tensor.shape = distances_shape;
distances_tensor.dl_tensor.strides = nullptr;

cuvsFilter prefilter;
prefilter.type = NO_FILTER;
prefilter.addr = (uintptr_t)NULL;

// search index
cuvsCagraSearchParams_t search_params;
cuvsCagraSearchParamsCreate(&search_params);
cuvsCagraSearch(res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor);
cuvsCagraSearch(
res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, prefilter);

// verify output
ASSERT_TRUE(
Expand All @@ -126,3 +136,114 @@ TEST(CagraC, BuildSearch)
cuvsCagraIndexDestroy(index);
cuvsResourcesDestroy(res);
}

TEST(CagraC, BuildSearchFiltered)
{
// create cuvsResources_t
cuvsResources_t res;
cuvsResourcesCreate(&res);
cudaStream_t stream;
cuvsStreamGet(res, &stream);

// create dataset DLTensor
DLManagedTensor dataset_tensor;
dataset_tensor.dl_tensor.data = dataset;
dataset_tensor.dl_tensor.device.device_type = kDLCPU;
dataset_tensor.dl_tensor.ndim = 2;
dataset_tensor.dl_tensor.dtype.code = kDLFloat;
dataset_tensor.dl_tensor.dtype.bits = 32;
dataset_tensor.dl_tensor.dtype.lanes = 1;
int64_t dataset_shape[2] = {4, 2};
dataset_tensor.dl_tensor.shape = dataset_shape;
dataset_tensor.dl_tensor.strides = nullptr;

// create index
cuvsCagraIndex_t index;
cuvsCagraIndexCreate(&index);

// build index
cuvsCagraIndexParams_t build_params;
cuvsCagraIndexParamsCreate(&build_params);
cuvsCagraBuild(res, build_params, &dataset_tensor, index);

// create queries DLTensor
rmm::device_uvector<float> queries_d(4 * 2, stream);
raft::copy(queries_d.data(), (float*)queries, 4 * 2, stream);

DLManagedTensor queries_tensor;
queries_tensor.dl_tensor.data = queries_d.data();
queries_tensor.dl_tensor.device.device_type = kDLCUDA;
queries_tensor.dl_tensor.ndim = 2;
queries_tensor.dl_tensor.dtype.code = kDLFloat;
queries_tensor.dl_tensor.dtype.bits = 32;
queries_tensor.dl_tensor.dtype.lanes = 1;
int64_t queries_shape[2] = {4, 2};
queries_tensor.dl_tensor.shape = queries_shape;
queries_tensor.dl_tensor.strides = nullptr;

// create neighbors DLTensor
rmm::device_uvector<uint32_t> neighbors_d(4, stream);

DLManagedTensor neighbors_tensor;
neighbors_tensor.dl_tensor.data = neighbors_d.data();
neighbors_tensor.dl_tensor.device.device_type = kDLCUDA;
neighbors_tensor.dl_tensor.ndim = 2;
neighbors_tensor.dl_tensor.dtype.code = kDLUInt;
neighbors_tensor.dl_tensor.dtype.bits = 32;
neighbors_tensor.dl_tensor.dtype.lanes = 1;
int64_t neighbors_shape[2] = {4, 1};
neighbors_tensor.dl_tensor.shape = neighbors_shape;
neighbors_tensor.dl_tensor.strides = nullptr;

// create distances DLTensor
rmm::device_uvector<float> distances_d(4, stream);

DLManagedTensor distances_tensor;
distances_tensor.dl_tensor.data = distances_d.data();
distances_tensor.dl_tensor.device.device_type = kDLCUDA;
distances_tensor.dl_tensor.ndim = 2;
distances_tensor.dl_tensor.dtype.code = kDLFloat;
distances_tensor.dl_tensor.dtype.bits = 32;
distances_tensor.dl_tensor.dtype.lanes = 1;
int64_t distances_shape[2] = {4, 1};
distances_tensor.dl_tensor.shape = distances_shape;
distances_tensor.dl_tensor.strides = nullptr;

// create filter DLTensor
rmm::device_uvector<int64_t> filter_d(2, stream);
raft::copy(filter_d.data(), (int64_t*)filter, 2, stream);

cuvsFilter filter;

DLManagedTensor filter_tensor;
filter_tensor.dl_tensor.data = filter_d.data();
filter_tensor.dl_tensor.device.device_type = kDLCUDA;
filter_tensor.dl_tensor.ndim = 1;
filter_tensor.dl_tensor.dtype.code = kDLInt;
filter_tensor.dl_tensor.dtype.bits = 64;
filter_tensor.dl_tensor.dtype.lanes = 1;
int64_t filter_shape[1] = {2};
filter_tensor.dl_tensor.shape = filter_shape;
filter_tensor.dl_tensor.strides = nullptr;

filter.type = BITSET;
filter.addr = (uintptr_t)&filter_tensor;

// search index
cuvsCagraSearchParams_t search_params;
cuvsCagraSearchParamsCreate(&search_params);
cuvsCagraSearch(
res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, filter);

// verify output
ASSERT_TRUE(cuvs::devArrMatchHost(
neighbors_exp_filtered, neighbors_d.data(), 4, cuvs::Compare<uint32_t>()));
ASSERT_TRUE(cuvs::devArrMatchHost(
distances_exp_filtered, distances_d.data(), 4, cuvs::CompareApprox<float>(0.001f)));

// de-allocate index and res
cuvsCagraSearchParamsDestroy(search_params);
cuvsCagraIndexParamsDestroy(build_params);
cuvsCagraIndexDestroy(index);
cuvsResourcesDestroy(res);
}