-
Notifications
You must be signed in to change notification settings - Fork 143
[Feat] Support bitset filter for Brute Force
#560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
1ba31da
4e30bd2
cbc5d38
3a5d4e0
8a45192
e79b1e3
8c0031a
4c53846
4a53e94
85d2dfc
5ef5bc5
f53d1ce
9beb58f
36bae13
1fcc7de
b58f2a5
6c7b583
7c4d50e
3ecccfb
6cc5059
4243fb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
|
|
||
| #include <cstdint> | ||
| #include <cuvs/distance/distance.hpp> | ||
| #include <raft/core/device_csr_matrix.hpp> | ||
| #include <raft/core/device_mdarray.hpp> | ||
| #include <raft/core/device_resources.hpp> | ||
| #include <raft/core/host_mdspan.hpp> | ||
|
|
@@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value; | |
|
|
||
| namespace filtering { | ||
|
|
||
| enum class FilterType { None, Bitmap, Bitset }; | ||
|
|
||
| struct base_filter { | ||
| virtual ~base_filter() = default; | ||
| virtual ~base_filter() = default; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I notice no changes have been made to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’ve just added the comments. I believe using bitset as the default setting might not be ideal if we don't have enough input from end-users. Perhaps we should discuss this in the team group, as I noticed that the none filter is also set as the default in CAGRA.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think you may have misunderstood me. The none filter is fine as the default for the the search functions, but for the code example in the docs, we should make sure we use a bitset and leave bitmap to users who need it. FAISS doesn't even support a bitmap and users aren't asking for it generally. It's good to keep it exposed for users who might need it. |
||
| virtual FilterType get_filter_type() const = 0; | ||
| }; | ||
|
|
||
| /* A filter that filters nothing. This is the default behavior. */ | ||
|
|
@@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter { | |
| const uint32_t query_ix, | ||
| // the index of the current sample | ||
| const uint32_t sample_ix) const; | ||
|
|
||
| FilterType get_filter_type() const override { return FilterType::None; } | ||
| }; | ||
|
|
||
| /** | ||
|
|
@@ -513,15 +519,24 @@ struct ivf_to_sample_filter { | |
| */ | ||
| template <typename bitmap_t, typename index_t> | ||
| struct bitmap_filter : public base_filter { | ||
| using view_t = cuvs::core::bitmap_view<bitmap_t, index_t>; | ||
|
|
||
| // View of the bitset to use as a filter | ||
| const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_view_; | ||
| const view_t bitmap_view_; | ||
|
|
||
| bitmap_filter(const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering); | ||
| bitmap_filter(const view_t bitmap_for_filtering); | ||
| inline _RAFT_HOST_DEVICE bool operator()( | ||
| // query index | ||
| const uint32_t query_ix, | ||
| // the index of the current sample | ||
| const uint32_t sample_ix) const; | ||
|
|
||
| FilterType get_filter_type() const override { return FilterType::Bitmap; } | ||
|
|
||
| view_t view() const { return bitmap_view_; } | ||
|
|
||
| template <typename csr_matrix_t> | ||
| void to_csr(raft::resources const& handle, csr_matrix_t& csr); | ||
| }; | ||
|
|
||
| /** | ||
|
|
@@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter { | |
| */ | ||
| template <typename bitset_t, typename index_t> | ||
| struct bitset_filter : public base_filter { | ||
| using view_t = cuvs::core::bitset_view<bitset_t, index_t>; | ||
|
|
||
| // View of the bitset to use as a filter | ||
| const cuvs::core::bitset_view<bitset_t, index_t> bitset_view_; | ||
| const view_t bitset_view_; | ||
|
|
||
| bitset_filter(const cuvs::core::bitset_view<bitset_t, index_t> bitset_for_filtering); | ||
| bitset_filter(const view_t bitset_for_filtering); | ||
| inline _RAFT_HOST_DEVICE bool operator()( | ||
| // query index | ||
| const uint32_t query_ix, | ||
| // the index of the current sample | ||
| const uint32_t sample_ix) const; | ||
|
|
||
| FilterType get_filter_type() const override { return FilterType::Bitset; } | ||
|
|
||
| view_t view() const { return bitset_view_; } | ||
|
|
||
| template <typename csr_matrix_t> | ||
| void to_csr(raft::resources const& handle, csr_matrix_t& csr); | ||
| }; | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,8 +67,8 @@ void _search(cuvsResources_t res, | |
| using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>; | ||
| using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>; | ||
| using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>; | ||
| using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we want to keep the filter immutable, don' we?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is to be compatible with the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are using |
||
| using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>; | ||
| using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>; | ||
| using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>; | ||
|
|
||
| auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor); | ||
| auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor); | ||
|
|
@@ -85,14 +85,14 @@ void _search(cuvsResources_t res, | |
| distances_mds, | ||
| cuvs::neighbors::filtering::none_sample_filter{}); | ||
| } else if (prefilter.type == BITMAP) { | ||
| auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr); | ||
| auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr); | ||
| auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter( | ||
| prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(), | ||
| auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr); | ||
| auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr); | ||
| const auto prefilter = cuvs::neighbors::filtering::bitmap_filter( | ||
| prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(), | ||
| queries_mds.extent(0), | ||
| index_ptr->dataset().extent(0))); | ||
| cuvs::neighbors::brute_force::search( | ||
| *res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view); | ||
| *res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter); | ||
| } else { | ||
| RAFT_FAIL("Unsupported prefilter type: BITSET"); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reorder these please? I think we should be pushing bitset as the first option.