Skip to content

Commit 6d4e2bf

Browse files
authored
Consolidate Index Constructors (#418)
* initial commit * updating python bindings to use new ctor * python binding error fix * error fix * reverting some changes -> experiment * removing redundnt code from native index * python build error fix * tyring to resolve python build error * attempt at python build fix * adding IndexSearchParams * setting search threads to non zero * minor check removed * eperiment 3-> making distance fully owned by data_store * exp 3 clang fix * exp 4 * making distance as unique_ptr * trying to fix build * finally fixing problem * some minor fix * adding dll export to index_factory static function * adding dll export for static fn in index_factory * code cleanup * resolving gopal's comments * resolving build failures
1 parent 977dd3c commit 6d4e2bf

20 files changed

+194
-223
lines changed

apps/build_memory_index.cpp

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,50 +22,6 @@
2222

2323
namespace po = boost::program_options;
2424

25-
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
26-
int build_in_memory_index(const diskann::Metric &metric, const std::string &data_path, const uint32_t R,
27-
const uint32_t L, const float alpha, const std::string &save_path, const uint32_t num_threads,
28-
const bool use_pq_build, const size_t num_pq_bytes, const bool use_opq,
29-
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
30-
{
31-
diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R)
32-
.with_filter_list_size(Lf)
33-
.with_alpha(alpha)
34-
.with_saturate_graph(false)
35-
.with_num_threads(num_threads)
36-
.build();
37-
std::string labels_file_to_use = save_path + "_label_formatted.txt";
38-
std::string mem_labels_int_map_file = save_path + "_labels_map.txt";
39-
40-
size_t data_num, data_dim;
41-
diskann::get_bin_metadata(data_path, data_num, data_dim);
42-
43-
diskann::Index<T, TagT, LabelT> index(metric, data_dim, data_num, false, false, false, use_pq_build, num_pq_bytes,
44-
use_opq);
45-
auto s = std::chrono::high_resolution_clock::now();
46-
if (label_file == "")
47-
{
48-
index.build(data_path.c_str(), data_num, paras);
49-
}
50-
else
51-
{
52-
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
53-
if (universal_label != "")
54-
{
55-
LabelT unv_label_as_num = 0;
56-
index.set_universal_label(unv_label_as_num);
57-
}
58-
index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num, paras);
59-
}
60-
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
61-
62-
std::cout << "Indexing time: " << diff.count() << "\n";
63-
index.save(save_path.c_str());
64-
if (label_file != "")
65-
std::remove(labels_file_to_use.c_str());
66-
return 0;
67-
}
68-
6925
int main(int argc, char **argv)
7026
{
7127
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;

apps/build_stitched_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, p
285285
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
286286

287287
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
288-
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, false, false);
288+
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false);
289289

290290
// not searching this index, set search_l to 0
291291
index.load(full_index_path_prefix.c_str(), num_threads, 1);

apps/test_insert_deletes_consolidate.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
152152
using TagT = uint32_t;
153153
auto data_type = diskann_type_to_name<T>();
154154
auto tag_type = diskann_type_to_name<TagT>();
155+
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
155156
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
156157
.with_metric(diskann::L2)
157158
.with_dimension(dim)
158159
.with_max_points(max_points_to_insert)
159160
.is_dynamic_index(true)
160161
.with_index_write_params(params)
161-
.with_search_threads(params.num_threads)
162-
.with_initial_search_list_size(params.search_list_size)
162+
.with_index_search_params(index_search_params)
163163
.with_data_type(data_type)
164164
.with_tag_type(tag_type)
165165
.with_data_load_store_strategy(diskann::MEMORY)

apps/test_streaming_scenario.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
186186
.with_num_frozen_points(num_start_pts)
187187
.build();
188188

189+
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
189190
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
190191
.with_max_occlusion_size(C)
191192
.with_alpha(alpha)
@@ -200,7 +201,6 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
200201
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
201202
<< std::endl;
202203
aligned_dim = ROUND_UP(dim, 8);
203-
204204
auto index_config = diskann::IndexConfigBuilder()
205205
.with_metric(diskann::L2)
206206
.with_dimension(dim)
@@ -210,12 +210,11 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con
210210
.is_use_opq(false)
211211
.with_num_pq_chunks(0)
212212
.is_pq_dist_build(false)
213-
.with_search_threads(insert_threads)
214-
.with_initial_search_list_size(L)
215213
.with_tag_type(diskann_type_to_name<TagT>())
216214
.with_label_type(diskann_type_to_name<LabelT>())
217215
.with_data_type(diskann_type_to_name<T>())
218216
.with_index_write_params(params)
217+
.with_index_search_params(index_search_params)
219218
.with_data_load_store_strategy(diskann::MEMORY)
220219
.build();
221220

apps/utils/count_bfs_levels.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ template <typename T> void bfs_count(const std::string &index_path, uint32_t dat
2727
{
2828
using TagT = uint32_t;
2929
using LabelT = uint32_t;
30-
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, false, false);
30+
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false);
3131
std::cout << "Index class instantiated" << std::endl;
3232
index.load(index_path.c_str(), 1, 100);
3333
std::cout << "Index loaded" << std::endl;

include/in_mem_data_store.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace diskann
2121
template <typename data_t> class InMemDataStore : public AbstractDataStore<data_t>
2222
{
2323
public:
24-
InMemDataStore(const location_t capacity, const size_t dim, std::shared_ptr<Distance<data_t>> distance_fn);
24+
InMemDataStore(const location_t capacity, const size_t dim, std::unique_ptr<Distance<data_t>> distance_fn);
2525
virtual ~InMemDataStore();
2626

2727
virtual location_t load(const std::string &filename) override;
@@ -73,7 +73,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
7373
// but this gives us perf benefits as the datastore can do distance
7474
// computations during search and compute norms of vectors internally without
7575
// have to copy data back and forth.
76-
std::shared_ptr<Distance<data_t>> _distance_fn;
76+
std::unique_ptr<Distance<data_t>> _distance_fn;
7777

7878
// in case we need to save vector norms for optimization
7979
std::shared_ptr<float[]> _pre_computed_norms;

include/index.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,16 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
4949
**************************************************************************/
5050

5151
public:
52-
// Constructor for Bulk operations and for creating the index object solely
53-
// for loading a prexisting index.
54-
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points = 1, const bool dynamic_index = false,
52+
// Call this when creating and passing Index Config is inconvenient.
53+
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
54+
const std::shared_ptr<IndexWriteParameters> index_parameters,
55+
const std::shared_ptr<IndexSearchParams> index_search_params,
56+
const size_t num_frozen_pts = 0, const bool dynamic_index = false,
5557
const bool enable_tags = false, const bool concurrent_consolidate = false,
5658
const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
57-
const bool use_opq = false, const size_t num_frozen_pts = 0,
58-
const bool init_data_store = true);
59-
60-
// Constructor for incremental index
61-
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index,
62-
const IndexWriteParameters &indexParameters, const uint32_t initial_search_list_size,
63-
const uint32_t search_threads, const bool enable_tags = false,
64-
const bool concurrent_consolidate = false, const bool pq_dist_build = false,
65-
const size_t num_pq_chunks = 0, const bool use_opq = false);
59+
const bool use_opq = false);
6660

61+
// This is called by IndexFactory which returns AbstractIndex's simplified API
6762
DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::unique_ptr<AbstractDataStore<T>> data_store
6863
/* std::unique_ptr<AbstractGraphStore> graph_store*/);
6964

@@ -329,7 +324,6 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
329324
private:
330325
// Distance functions
331326
Metric _dist_metric = diskann::L2;
332-
std::shared_ptr<Distance<T>> _distance;
333327

334328
// Data
335329
std::unique_ptr<AbstractDataStore<T>> _data_store;

include/index_config.h

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,23 @@ struct IndexConfig
3333
std::string tag_type;
3434
std::string data_type;
3535

36+
// Params for building index
3637
std::shared_ptr<IndexWriteParameters> index_write_params;
37-
38-
uint32_t search_threads;
39-
uint32_t initial_search_list_size;
38+
// Params for searching index
39+
std::shared_ptr<IndexSearchParams> index_search_params;
4040

4141
private:
4242
IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension,
4343
size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags,
4444
bool pq_dist_build, bool concurrent_consolidate, bool use_opq, const std::string &data_type,
4545
const std::string &tag_type, const std::string &label_type,
46-
std::shared_ptr<IndexWriteParameters> index_write_params, uint32_t search_threads,
47-
uint32_t initial_search_list_size)
46+
std::shared_ptr<IndexWriteParameters> index_write_params,
47+
std::shared_ptr<IndexSearchParams> index_search_params)
4848
: data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension),
4949
max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build),
5050
concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), num_pq_chunks(num_pq_chunks),
5151
num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), data_type(data_type),
52-
index_write_params(index_write_params), search_threads(search_threads),
53-
initial_search_list_size(initial_search_list_size)
52+
index_write_params(index_write_params), index_search_params(index_search_params)
5453
{
5554
}
5655

@@ -60,9 +59,7 @@ struct IndexConfig
6059
class IndexConfigBuilder
6160
{
6261
public:
63-
IndexConfigBuilder()
64-
{
65-
}
62+
IndexConfigBuilder() = default;
6663

6764
IndexConfigBuilder &with_metric(Metric m)
6865
{
@@ -160,15 +157,31 @@ class IndexConfigBuilder
160157
return *this;
161158
}
162159

163-
IndexConfigBuilder &with_search_threads(uint32_t search_threads)
160+
IndexConfigBuilder &with_index_write_params(std::shared_ptr<IndexWriteParameters> index_write_params_ptr)
161+
{
162+
if (index_write_params_ptr == nullptr)
163+
{
164+
diskann::cout << "Passed, empty build_params while creating index config" << std::endl;
165+
return *this;
166+
}
167+
this->_index_write_params = index_write_params_ptr;
168+
return *this;
169+
}
170+
171+
IndexConfigBuilder &with_index_search_params(IndexSearchParams &search_params)
164172
{
165-
this->_search_threads = search_threads;
173+
this->_index_search_params = std::make_shared<IndexSearchParams>(search_params);
166174
return *this;
167175
}
168176

169-
IndexConfigBuilder &with_initial_search_list_size(uint32_t search_list_size)
177+
IndexConfigBuilder &with_index_search_params(std::shared_ptr<IndexSearchParams> search_params_ptr)
170178
{
171-
this->_initial_search_list_size = search_list_size;
179+
if (search_params_ptr == nullptr)
180+
{
181+
diskann::cout << "Passed, empty search_params while creating index config" << std::endl;
182+
return *this;
183+
}
184+
this->_index_search_params = search_params_ptr;
172185
return *this;
173186
}
174187

@@ -177,19 +190,20 @@ class IndexConfigBuilder
177190
if (_data_type == "" || _data_type.empty())
178191
throw ANNException("Error: data_type can not be empty", -1);
179192

180-
if (_dynamic_index && _index_write_params != nullptr)
193+
if (_dynamic_index && _num_frozen_pts == 0)
181194
{
182-
if (_search_threads == 0)
183-
throw ANNException("Error: please pass search_threads for building dynamic index.", -1);
195+
_num_frozen_pts = 1;
196+
}
184197

185-
if (_initial_search_list_size == 0)
198+
if (_dynamic_index)
199+
{
200+
if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0)
186201
throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1);
187202
}
188203

189204
return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks,
190205
_num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate,
191-
_use_opq, _data_type, _tag_type, _label_type, _index_write_params, _search_threads,
192-
_initial_search_list_size);
206+
_use_opq, _data_type, _tag_type, _label_type, _index_write_params, _index_search_params);
193207
}
194208

195209
IndexConfigBuilder(const IndexConfigBuilder &) = delete;
@@ -217,8 +231,6 @@ class IndexConfigBuilder
217231
std::string _data_type;
218232

219233
std::shared_ptr<IndexWriteParameters> _index_write_params;
220-
221-
uint32_t _search_threads;
222-
uint32_t _initial_search_list_size;
234+
std::shared_ptr<IndexSearchParams> _index_search_params;
223235
};
224236
} // namespace diskann

include/index_factory.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ class IndexFactory
1010
DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config);
1111
DISKANN_DLLEXPORT std::unique_ptr<AbstractIndex> create_instance();
1212

13+
// Consruct a data store with distance function emplaced within
14+
template <typename T>
15+
DISKANN_DLLEXPORT static std::unique_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy stratagy,
16+
size_t num_points,
17+
size_t dimension, Metric m);
18+
1319
private:
1420
void check_config();
1521

16-
template <typename T>
17-
std::unique_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy stratagy, size_t num_points,
18-
size_t dimension);
19-
2022
std::unique_ptr<AbstractGraphStore> construct_graphstore(GraphStoreStrategy stratagy, size_t size);
2123

2224
template <typename data_type, typename tag_type, typename label_type>

include/parameters.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ class IndexWriteParameters
3838
friend class IndexWriteParametersBuilder;
3939
};
4040

41+
class IndexSearchParams
42+
{
43+
public:
44+
IndexSearchParams(const uint32_t initial_search_list_size, const uint32_t num_search_threads)
45+
: initial_search_list_size(initial_search_list_size), num_search_threads(num_search_threads)
46+
{
47+
}
48+
const uint32_t initial_search_list_size; // search L
49+
const uint32_t num_search_threads; // search threads
50+
};
51+
4152
class IndexWriteParametersBuilder
4253
{
4354
/**

0 commit comments

Comments
 (0)