Skip to content

Commit c778c88

Browse files
authored
[Feat] Support bitset filter for Brute Force (#560)
Authors: - rhdong (https://github.com/rhdong) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #560
1 parent 833f28c commit c778c88

File tree

7 files changed

+750
-99
lines changed

7 files changed

+750
-99
lines changed

cpp/include/cuvs/neighbors/brute_force.hpp

Lines changed: 140 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,28 @@ auto build(raft::resources const& handle,
332332
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
333333
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
334334
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
335-
* eliminate entirely allocations happening within `search`:
335+
* eliminate entirely allocations happening within `search`.
336+
*
337+
* Usage example:
336338
* @code{.cpp}
337-
* ...
338-
* // Use the same allocator across multiple searches to reduce the number of
339-
* // cuda memory allocations
340-
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
341-
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
342-
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
343-
* ...
339+
* using namespace cuvs::neighbors;
340+
*
341+
* // use default index parameters
342+
* brute_force::index_params index_params;
343+
* // create and fill the index from a [N, D] dataset
344+
* brute_force::index_params index_params;
345+
* auto index = brute_force::build(handle, index_params, dataset);
346+
* // use default search parameters
347+
* brute_force::search_params search_params;
348+
* // create a bitset to filter the search
349+
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
350+
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
351+
* res, removed_indices.view(), dataset.extent(0));
352+
* // search K nearest neighbours according to a bitset
353+
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
354+
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
355+
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
356+
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
344357
* @endcode
345358
*
346359
* @param[in] handle
@@ -350,9 +363,17 @@ auto build(raft::resources const& handle,
350363
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
351364
* [n_queries, k]
352365
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
353-
* @param[in] sample_filter An optional device bitmap filter function with a `row-major` layout and
354-
* the shape of [n_queries, index->size()], which means the filter will use the first
355-
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
366+
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
367+
* be considered for each query.
368+
*
369+
* - Supports two types of filters:
370+
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
371+
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
372+
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
373+
* where each bit indicates whether a specific dataset element should be considered for a
374+
* particular query. (1 for inclusion, 0 for exclusion).
375+
*
376+
* - The default value is `none_sample_filter`, which applies no filtering.
356377
*/
357378
void search(raft::resources const& handle,
358379
const cuvs::neighbors::brute_force::search_params& params,
@@ -379,15 +400,28 @@ void search(raft::resources const& handle,
379400
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
380401
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
381402
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
382-
* eliminate entirely allocations happening within `search`:
403+
* eliminate entirely allocations happening within `search`.
404+
*
405+
* Usage example:
383406
* @code{.cpp}
384-
* ...
385-
* // Use the same allocator across multiple searches to reduce the number of
386-
* // cuda memory allocations
387-
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
388-
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
389-
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
390-
* ...
407+
* using namespace cuvs::neighbors;
408+
*
409+
* // use default index parameters
410+
* brute_force::index_params index_params;
411+
* // create and fill the index from a [N, D] dataset
412+
* brute_force::index_params index_params;
413+
* auto index = brute_force::build(handle, index_params, dataset);
414+
* // use default search parameters
415+
* brute_force::search_params search_params;
416+
* // create a bitset to filter the search
417+
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
418+
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
419+
* res, removed_indices.view(), dataset.extent(0));
420+
* // search K nearest neighbours according to a bitset
421+
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
422+
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
423+
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
424+
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
391425
* @endcode
392426
*
393427
* @param[in] handle
@@ -397,8 +431,17 @@ void search(raft::resources const& handle,
397431
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
398432
* [n_queries, k]
399433
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
400-
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
401-
* given
434+
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
435+
* be considered for each query.
436+
*
437+
* - Supports two types of filters:
438+
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
439+
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
440+
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
441+
* where each bit indicates whether a specific dataset element should be considered for a
442+
* particular query. (1 for inclusion, 0 for exclusion).
443+
*
444+
* - The default value is `none_sample_filter`, which applies no filtering.
402445
*/
403446
void search(raft::resources const& handle,
404447
const cuvs::neighbors::brute_force::search_params& params,
@@ -421,15 +464,51 @@ void search(raft::resources const& handle,
421464
*
422465
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
423466
*
467+
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
468+
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
469+
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
470+
* eliminate entirely allocations happening within `search`.
471+
*
472+
* Usage example:
473+
* @code{.cpp}
474+
* using namespace cuvs::neighbors;
475+
*
476+
* // use default index parameters
477+
* brute_force::index_params index_params;
478+
* // create and fill the index from a [N, D] dataset
479+
* brute_force::index_params index_params;
480+
* auto index = brute_force::build(handle, index_params, dataset);
481+
* // use default search parameters
482+
* brute_force::search_params search_params;
483+
* // create a bitset to filter the search
484+
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
485+
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
486+
* res, removed_indices.view(), dataset.extent(0));
487+
* // search K nearest neighbours according to a bitset
488+
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
489+
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
490+
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
491+
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
492+
* @endcode
493+
*
424494
* @param[in] handle
425495
* @param[in] params parameters configuring the search
426496
* @param[in] index bruteforce constructed index
427497
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
428498
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
429499
* [n_queries, k]
430500
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
431-
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
432-
* given query
501+
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
502+
* be considered for each query.
503+
*
504+
* - Supports two types of filters:
505+
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
506+
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
507+
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
508+
* where each bit indicates whether a specific dataset element should be considered for a
509+
* particular query. (1 for inclusion, 0 for exclusion).
510+
*
511+
* - The default value is `none_sample_filter`, which applies no filtering.
433512
*/
434513
void search(raft::resources const& handle,
435514
const cuvs::neighbors::brute_force::search_params& params,
@@ -452,15 +531,51 @@ void search(raft::resources const& handle,
452531
*
453532
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
454533
*
534+
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
535+
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
536+
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
537+
* eliminate entirely allocations happening within `search`.
538+
*
539+
* Usage example:
540+
* @code{.cpp}
541+
* using namespace cuvs::neighbors;
542+
*
543+
* // use default index parameters
544+
* brute_force::index_params index_params;
545+
* // create and fill the index from a [N, D] dataset
546+
* brute_force::index_params index_params;
547+
* auto index = brute_force::build(handle, index_params, dataset);
548+
* // use default search parameters
549+
* brute_force::search_params search_params;
550+
* // create a bitset to filter the search
551+
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
552+
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
553+
* res, removed_indices.view(), dataset.extent(0));
554+
* // search K nearest neighbours according to a bitset
555+
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
556+
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
557+
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
558+
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
559+
* @endcode
560+
*
455561
* @param[in] handle
456562
* @param[in] params parameters configuring the search
457563
* @param[in] index bruteforce constructed index
458564
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
459565
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
460566
* [n_queries, k]
461567
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
462-
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
463-
* given query
568+
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
569+
* be considered for each query.
570+
*
571+
* - Supports two types of filters:
572+
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
573+
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
574+
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
575+
* where each bit indicates whether a specific dataset element should be considered for a
576+
* particular query. (1 for inclusion, 0 for exclusion).
577+
*
578+
* - The default value is `none_sample_filter`, which applies no filtering.
464579
*/
465580
void search(raft::resources const& handle,
466581
const cuvs::neighbors::brute_force::search_params& params,

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <cstdint>
2020
#include <cuvs/distance/distance.hpp>
21+
#include <raft/core/device_csr_matrix.hpp>
2122
#include <raft/core/device_mdarray.hpp>
2223
#include <raft/core/device_resources.hpp>
2324
#include <raft/core/host_mdspan.hpp>
@@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;
456457

457458
namespace filtering {
458459

460+
enum class FilterType { None, Bitmap, Bitset };
461+
459462
struct base_filter {
460-
virtual ~base_filter() = default;
463+
virtual ~base_filter() = default;
464+
virtual FilterType get_filter_type() const = 0;
461465
};
462466

463467
/* A filter that filters nothing. This is the default behavior. */
@@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter {
475479
const uint32_t query_ix,
476480
// the index of the current sample
477481
const uint32_t sample_ix) const;
482+
483+
FilterType get_filter_type() const override { return FilterType::None; }
478484
};
479485

480486
/**
@@ -513,15 +519,24 @@ struct ivf_to_sample_filter {
513519
*/
514520
template <typename bitmap_t, typename index_t>
515521
struct bitmap_filter : public base_filter {
522+
using view_t = cuvs::core::bitmap_view<bitmap_t, index_t>;
523+
516524
// View of the bitset to use as a filter
517-
const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_view_;
525+
const view_t bitmap_view_;
518526

519-
bitmap_filter(const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering);
527+
bitmap_filter(const view_t bitmap_for_filtering);
520528
inline _RAFT_HOST_DEVICE bool operator()(
521529
// query index
522530
const uint32_t query_ix,
523531
// the index of the current sample
524532
const uint32_t sample_ix) const;
533+
534+
FilterType get_filter_type() const override { return FilterType::Bitmap; }
535+
536+
view_t view() const { return bitmap_view_; }
537+
538+
template <typename csr_matrix_t>
539+
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
525540
};
526541

527542
/**
@@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter {
532547
*/
533548
template <typename bitset_t, typename index_t>
534549
struct bitset_filter : public base_filter {
550+
using view_t = cuvs::core::bitset_view<bitset_t, index_t>;
551+
535552
// View of the bitset to use as a filter
536-
const cuvs::core::bitset_view<bitset_t, index_t> bitset_view_;
553+
const view_t bitset_view_;
537554

538-
bitset_filter(const cuvs::core::bitset_view<bitset_t, index_t> bitset_for_filtering);
555+
bitset_filter(const view_t bitset_for_filtering);
539556
inline _RAFT_HOST_DEVICE bool operator()(
540557
// query index
541558
const uint32_t query_ix,
542559
// the index of the current sample
543560
const uint32_t sample_ix) const;
561+
562+
FilterType get_filter_type() const override { return FilterType::Bitset; }
563+
564+
view_t view() const { return bitset_view_; }
565+
566+
template <typename csr_matrix_t>
567+
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
544568
};
545569

546570
/**

cpp/src/neighbors/brute_force_c.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ void _search(cuvsResources_t res,
6767
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
6868
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
6969
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
70-
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
71-
using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>;
70+
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>;
71+
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>;
7272

7373
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
7474
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
@@ -85,14 +85,14 @@ void _search(cuvsResources_t res,
8585
distances_mds,
8686
cuvs::neighbors::filtering::none_sample_filter{});
8787
} else if (prefilter.type == BITMAP) {
88-
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
89-
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
90-
auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter(
91-
prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(),
88+
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
89+
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
90+
const auto prefilter = cuvs::neighbors::filtering::bitmap_filter(
91+
prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(),
9292
queries_mds.extent(0),
9393
index_ptr->dataset().extent(0)));
9494
cuvs::neighbors::brute_force::search(
95-
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view);
95+
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter);
9696
} else {
9797
RAFT_FAIL("Unsupported prefilter type: BITSET");
9898
}

0 commit comments

Comments
 (0)