Skip to content

Commit 0dd7bde

Browse files
authored
Fix cagra_hnsw serialization when dataset is not part of index (#591)
After calling `build()`, ideally the CAGRA index contains both the dataset and the graph. But when we do not have sufficient device memory, then only the graph is returned. In such case we need to pass the dataset explicitly to the serialization routines. For serialization in HNSW format, in case we have flat hierarchy, the dataset was not passed. This PR fixes this problem by adding an optional `dataset` argument to `cagra::serialize_to_hnswlib`. Furthermore, to improve execution time, we change from writing a single element to writing a single row of the graph and dataset at time. Additionally, debug messages for tracking data saving time are added. Authors: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) URL: #591
1 parent 836183e commit 0dd7bde

File tree

5 files changed

+191
-102
lines changed

5 files changed

+191
-102
lines changed

cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
#include "cuvs_cagra_wrapper.h"
1919
#include <cuvs/neighbors/hnsw.hpp>
20+
#include <raft/core/logger.hpp>
2021

22+
#include <chrono>
2123
#include <memory>
2224

2325
namespace cuvs::bench {
@@ -90,8 +92,13 @@ void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
9092
auto host_dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
9193
auto opt_dataset_view =
9294
std::optional<raft::host_matrix_view<const T, int64_t>>(std::move(host_dataset_view));
93-
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
95+
const auto start_clock = std::chrono::system_clock::now();
96+
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
9497
handle_, build_param_.hnsw_index_params, *cagra_index, opt_dataset_view);
98+
int time =
99+
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - start_clock)
100+
.count();
101+
RAFT_LOG_DEBUG("Graph saved to HNSW format in %d:%d min", time / 60, time % 60);
95102
}
96103

97104
template <typename T, typename IdxT>

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,11 +1599,16 @@ void deserialize(raft::resources const& handle,
15991599
* @param[in] handle the raft handle
16001600
* @param[in] os output stream
16011601
* @param[in] index CAGRA index
1602+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1603+
* does not contain the dataset.
16021604
*
16031605
*/
1604-
void serialize_to_hnswlib(raft::resources const& handle,
1605-
std::ostream& os,
1606-
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
1606+
void serialize_to_hnswlib(
1607+
raft::resources const& handle,
1608+
std::ostream& os,
1609+
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
1610+
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
1611+
std::nullopt);
16071612

16081613
/**
16091614
* Save a CAGRA build index in hnswlib base-layer-only serialized format
@@ -1628,11 +1633,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
16281633
* @param[in] handle the raft handle
16291634
* @param[in] filename the file name for saving the index
16301635
* @param[in] index CAGRA index
1636+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1637+
* does not contain the dataset.
16311638
*
16321639
*/
1633-
void serialize_to_hnswlib(raft::resources const& handle,
1634-
const std::string& filename,
1635-
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
1640+
void serialize_to_hnswlib(
1641+
raft::resources const& handle,
1642+
const std::string& filename,
1643+
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
1644+
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
1645+
std::nullopt);
16361646

16371647
/**
16381648
* Write the CAGRA built index as a base layer HNSW index to an output stream
@@ -1656,11 +1666,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
16561666
* @param[in] handle the raft handle
16571667
* @param[in] os output stream
16581668
* @param[in] index CAGRA index
1669+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1670+
* does not contain the dataset.
16591671
*
16601672
*/
1661-
void serialize_to_hnswlib(raft::resources const& handle,
1662-
std::ostream& os,
1663-
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
1673+
void serialize_to_hnswlib(
1674+
raft::resources const& handle,
1675+
std::ostream& os,
1676+
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
1677+
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
1678+
std::nullopt);
16641679

16651680
/**
16661681
* Save a CAGRA build index in hnswlib base-layer-only serialized format
@@ -1685,11 +1700,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
16851700
* @param[in] handle the raft handle
16861701
* @param[in] filename the file name for saving the index
16871702
* @param[in] index CAGRA index
1703+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1704+
* does not contain the dataset.
16881705
*
16891706
*/
1690-
void serialize_to_hnswlib(raft::resources const& handle,
1691-
const std::string& filename,
1692-
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
1707+
void serialize_to_hnswlib(
1708+
raft::resources const& handle,
1709+
const std::string& filename,
1710+
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
1711+
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
1712+
std::nullopt);
16931713

16941714
/**
16951715
* Write the CAGRA built index as a base layer HNSW index to an output stream
@@ -1713,11 +1733,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
17131733
* @param[in] handle the raft handle
17141734
* @param[in] os output stream
17151735
* @param[in] index CAGRA index
1736+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1737+
* does not contain the dataset.
17161738
*
17171739
*/
1718-
void serialize_to_hnswlib(raft::resources const& handle,
1719-
std::ostream& os,
1720-
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
1740+
void serialize_to_hnswlib(
1741+
raft::resources const& handle,
1742+
std::ostream& os,
1743+
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
1744+
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
1745+
std::nullopt);
17211746

17221747
/**
17231748
* Save a CAGRA build index in hnswlib base-layer-only serialized format
@@ -1742,11 +1767,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
17421767
* @param[in] handle the raft handle
17431768
* @param[in] filename the file name for saving the index
17441769
* @param[in] index CAGRA index
1770+
* @param[in] dataset [optional] host array that stores the dataset, required if the index
1771+
* does not contain the dataset.
17451772
*
17461773
*/
1747-
void serialize_to_hnswlib(raft::resources const& handle,
1748-
const std::string& filename,
1749-
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
1774+
void serialize_to_hnswlib(
1775+
raft::resources const& handle,
1776+
const std::string& filename,
1777+
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
1778+
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
1779+
std::nullopt);
17501780

17511781
/**
17521782
* @}

cpp/src/neighbors/cagra_serialize.cuh

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,51 +20,56 @@
2020

2121
namespace cuvs::neighbors::cagra {
2222

23-
#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \
24-
void serialize(raft::resources const& handle, \
25-
const std::string& filename, \
26-
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
27-
bool include_dataset) \
28-
{ \
29-
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
30-
handle, filename, index, include_dataset); \
31-
}; \
32-
\
33-
void deserialize(raft::resources const& handle, \
34-
const std::string& filename, \
35-
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
36-
{ \
37-
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index); \
38-
}; \
39-
void serialize(raft::resources const& handle, \
40-
std::ostream& os, \
41-
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
42-
bool include_dataset) \
43-
{ \
44-
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
45-
handle, os, index, include_dataset); \
46-
} \
47-
\
48-
void deserialize(raft::resources const& handle, \
49-
std::istream& is, \
50-
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
51-
{ \
52-
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index); \
53-
} \
54-
\
55-
void serialize_to_hnswlib(raft::resources const& handle, \
56-
std::ostream& os, \
57-
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
58-
{ \
59-
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>(handle, os, index); \
60-
} \
61-
\
62-
void serialize_to_hnswlib(raft::resources const& handle, \
63-
const std::string& filename, \
64-
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
65-
{ \
66-
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
67-
handle, filename, index); \
23+
#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \
24+
void serialize(raft::resources const& handle, \
25+
const std::string& filename, \
26+
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
27+
bool include_dataset) \
28+
{ \
29+
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
30+
handle, filename, index, include_dataset); \
31+
}; \
32+
\
33+
void deserialize(raft::resources const& handle, \
34+
const std::string& filename, \
35+
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
36+
{ \
37+
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index); \
38+
}; \
39+
void serialize(raft::resources const& handle, \
40+
std::ostream& os, \
41+
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
42+
bool include_dataset) \
43+
{ \
44+
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
45+
handle, os, index, include_dataset); \
46+
} \
47+
\
48+
void deserialize(raft::resources const& handle, \
49+
std::istream& is, \
50+
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
51+
{ \
52+
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index); \
53+
} \
54+
\
55+
void serialize_to_hnswlib( \
56+
raft::resources const& handle, \
57+
std::ostream& os, \
58+
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
59+
std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset) \
60+
{ \
61+
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
62+
handle, os, index, dataset); \
63+
} \
64+
\
65+
void serialize_to_hnswlib( \
66+
raft::resources const& handle, \
67+
const std::string& filename, \
68+
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
69+
std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset) \
70+
{ \
71+
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
72+
handle, filename, index, dataset); \
6873
}
6974

7075
} // namespace cuvs::neighbors::cagra

0 commit comments

Comments
 (0)