-
Notifications
You must be signed in to change notification settings - Fork 143
Moving MG functions into unified API + raft::device_resources_snmg as device resource type for MG functions
#454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 25 commits
ee98593
3a99b40
e16b68e
cdd5cfb
96e69fc
657bf9e
1fdccd4
d7fff4c
45a41fa
b0c7ab9
db12ab9
d5f1500
db4cf11
2f10065
ddaeee9
3075869
4a422df
39bda6b
886ba3f
028408b
4e0f512
9b6cd88
da48376
d2e2be9
1e1d97e
34d67ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,8 +17,8 @@ | |
|
|
||
| #include "cuvs_ann_bench_utils.h" | ||
| #include "cuvs_cagra_wrapper.h" | ||
| #include <cuvs/neighbors/mg.hpp> | ||
| #include <raft/core/resource/nccl_clique.hpp> | ||
| #include <cuvs/neighbors/cagra.hpp> | ||
| #include <raft/core/device_resources_snmg.hpp> | ||
|
|
||
| namespace cuvs::bench { | ||
| using namespace cuvs::neighbors; | ||
|
|
@@ -33,21 +33,20 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu { | |
| using algo<T>::dim_; | ||
|
|
||
| struct build_param : public cuvs::bench::cuvs_cagra<T, IdxT>::build_param { | ||
| cuvs::neighbors::mg::distribution_mode mode; | ||
| cuvs::neighbors::distribution_mode mode; | ||
| }; | ||
|
|
||
| struct search_param : public cuvs::bench::cuvs_cagra<T, IdxT>::search_param { | ||
| cuvs::neighbors::mg::sharded_merge_mode merge_mode; | ||
| cuvs::neighbors::sharded_merge_mode merge_mode; | ||
| }; | ||
|
|
||
| cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) | ||
| : algo<T>(metric, dim), index_params_(param) | ||
| : algo<T>(metric, dim), index_params_(param), clique_() | ||
| { | ||
| index_params_.cagra_params.metric = parse_metric_type(metric); | ||
| index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); | ||
|
|
||
| // init nccl clique outside as to not affect benchmark | ||
| const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); | ||
| clique_.set_memory_pool(80); | ||
| } | ||
|
|
||
| void build(const T* dataset, size_t nrow) final; | ||
|
|
@@ -69,7 +68,7 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu { | |
|
|
||
| [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override | ||
| { | ||
| auto stream = raft::resource::get_cuda_stream(handle_); | ||
| auto stream = raft::resource::get_cuda_stream(clique_); | ||
| return stream; | ||
| } | ||
|
|
||
|
|
@@ -87,11 +86,11 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu { | |
| std::unique_ptr<algo<T>> copy() override; | ||
|
|
||
| private: | ||
| raft::device_resources handle_; | ||
| raft::device_resources_snmg clique_; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to store
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it's not a huge deal since this file is specific to mg. Could we at least rename the file to |
||
| float refine_ratio_; | ||
| build_param index_params_; | ||
| cuvs::neighbors::mg::search_params<cagra::search_params> search_params_; | ||
| std::shared_ptr<cuvs::neighbors::mg::index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>> | ||
| cuvs::neighbors::mg_search_params<cagra::search_params> search_params_; | ||
| std::shared_ptr<cuvs::neighbors::mg_index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>> | ||
| index_; | ||
| }; | ||
|
|
||
|
|
@@ -100,14 +99,14 @@ void cuvs_mg_cagra<T, IdxT>::build(const T* dataset, size_t nrow) | |
| { | ||
| auto dataset_extents = raft::make_extents<IdxT>(nrow, dim_); | ||
| index_params_.prepare_build_params(dataset_extents); | ||
| cuvs::neighbors::mg::index_params<cagra::index_params> build_params = index_params_.cagra_params; | ||
| build_params.mode = index_params_.mode; | ||
| cuvs::neighbors::mg_index_params<cagra::index_params> build_params = index_params_.cagra_params; | ||
| build_params.mode = index_params_.mode; | ||
|
|
||
| auto dataset_view = | ||
| raft::make_host_matrix_view<const T, int64_t, raft::row_major>(dataset, nrow, dim_); | ||
| auto idx = cuvs::neighbors::mg::build(handle_, build_params, dataset_view); | ||
| auto idx = cuvs::neighbors::cagra::build(clique_, build_params, dataset_view); | ||
| index_ = | ||
| std::make_shared<cuvs::neighbors::mg::index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>( | ||
| std::make_shared<cuvs::neighbors::mg_index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>( | ||
| std::move(idx)); | ||
| } | ||
|
|
||
|
|
@@ -118,8 +117,7 @@ void cuvs_mg_cagra<T, IdxT>::set_search_param(const search_param_base& param, | |
| const void* filter_bitset) | ||
| { | ||
| if (filter_bitset != nullptr) { throw std::runtime_error("Filtering is not supported yet."); } | ||
| auto sp = dynamic_cast<const search_param&>(param); | ||
| // search_params_ = static_cast<mg::search_params<cagra::search_params>>(sp.p); | ||
| auto sp = dynamic_cast<const search_param&>(param); | ||
| cagra::search_params* search_params_ptr_ = static_cast<cagra::search_params*>(&search_params_); | ||
| *search_params_ptr_ = sp.p; | ||
| search_params_.merge_mode = sp.merge_mode; | ||
|
|
@@ -134,15 +132,15 @@ void cuvs_mg_cagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow) | |
| template <typename T, typename IdxT> | ||
| void cuvs_mg_cagra<T, IdxT>::save(const std::string& file) const | ||
| { | ||
| cuvs::neighbors::mg::serialize(handle_, *index_, file); | ||
| cuvs::neighbors::cagra::serialize(clique_, *index_, file); | ||
| } | ||
|
|
||
| template <typename T, typename IdxT> | ||
| void cuvs_mg_cagra<T, IdxT>::load(const std::string& file) | ||
| { | ||
| index_ = | ||
| std::make_shared<cuvs::neighbors::mg::index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>( | ||
| std::move(cuvs::neighbors::mg::deserialize_cagra<T, IdxT>(handle_, file))); | ||
| std::make_shared<cuvs::neighbors::mg_index<cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>( | ||
| std::move(cuvs::neighbors::cagra::deserialize<T, IdxT>(clique_, file))); | ||
| } | ||
|
|
||
| template <typename T, typename IdxT> | ||
|
|
@@ -165,8 +163,8 @@ void cuvs_mg_cagra<T, IdxT>::search_base( | |
| auto distances_view = | ||
| raft::make_host_matrix_view<float, int64_t, raft::row_major>(distances, batch_size, k); | ||
|
|
||
| cuvs::neighbors::mg::search( | ||
| handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); | ||
| cuvs::neighbors::cagra::search( | ||
| clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); | ||
| } | ||
|
|
||
| template <typename T, typename IdxT> | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.