Skip to content

Commit a2a6a67

Browse files
authored
Merge pull request #719 from rapidsai/branch-25.02
Forward-merge branch-25.02 into branch-25.04
2 parents a1e0cc0 + 1591029 commit a2a6a67

File tree

16 files changed

+358
-146
lines changed

16 files changed

+358
-146
lines changed

conda/recipes/cuvs-bench-cpu/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ requirements:
6262
- pyyaml
6363
- python
6464
- requests
65+
- sklearn>=1.5
6566
about:
6667
home: https://rapids.ai/
6768
license: Apache-2.0

conda/recipes/cuvs-bench/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ requirements:
101101
- python
102102
- requests
103103
- rmm ={{ minor_version }}
104+
- sklearn>=1.5
104105
about:
105106
home: https://rapids.ai/
106107
license: Apache-2.0

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ struct index : cuvs::neighbors::index {
337337
using search_params_type = cagra::search_params;
338338
using index_type = IdxT;
339339
using value_type = T;
340+
using dataset_index_type = int64_t;
341+
340342
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
341343
"IdxT must be able to represent all values of uint32_t");
342344

@@ -510,14 +512,14 @@ struct index : cuvs::neighbors::index {
510512
*/
511513
template <typename DatasetT>
512514
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
513-
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
515+
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>>
514516
{
515517
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
516518
}
517519

518520
template <typename DatasetT>
519521
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
520-
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
522+
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<dataset_index_type>, DatasetT>>
521523
{
522524
dataset_ = std::move(dataset);
523525
}
@@ -561,7 +563,7 @@ struct index : cuvs::neighbors::index {
561563
cuvs::distance::DistanceType metric_;
562564
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
563565
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
564-
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
566+
std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
565567
};
566568
/**
567569
* @}

cpp/src/neighbors/detail/cagra/cagra_merge.cuh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,16 @@ index<T, IdxT> merge(raft::resources const& handle,
4343
const cagra::merge_params& params,
4444
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)
4545
{
46+
using cagra_index_t = cuvs::neighbors::cagra::index<T, IdxT>;
47+
using ds_idx_type = typename cagra_index_t::dataset_index_type;
48+
4649
std::size_t dim = 0;
4750
std::size_t new_dataset_size = 0;
4851
int64_t stride = -1;
4952

50-
for (auto index : indices) {
53+
for (cagra_index_t* index : indices) {
5154
RAFT_EXPECTS(index != nullptr,
5255
"Null pointer detected in 'indices'. Ensure all elements are valid before usage.");
53-
using ds_idx_type = decltype(index->data().n_rows());
5456
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
5557
strided_dset != nullptr) {
5658
if (dim == 0) {
@@ -74,8 +76,7 @@ index<T, IdxT> merge(raft::resources const& handle,
7476
IdxT offset = 0;
7577

7678
auto merge_dataset = [&](T* dst) {
77-
for (auto index : indices) {
78-
using ds_idx_type = decltype(index->data().n_rows());
79+
for (cagra_index_t* index : indices) {
7980
auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
8081

8182
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst + offset * dim,

cpp/src/neighbors/detail/nn_descent.cuh

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,24 +1047,32 @@ void GnndGraph<Index_t>::init_random_graph()
10471047
for (size_t seg_idx = 0; seg_idx < static_cast<size_t>(num_segments); seg_idx++) {
10481048
// random sequence (range: 0~nrow)
10491049
// segment_x stores neighbors which id % num_segments == x
1050-
std::vector<Index_t> rand_seq(nrow / num_segments);
1050+
std::vector<Index_t> rand_seq((nrow + num_segments - 1) / num_segments);
10511051
std::iota(rand_seq.begin(), rand_seq.end(), 0);
10521052
auto gen = std::default_random_engine{seg_idx};
10531053
std::shuffle(rand_seq.begin(), rand_seq.end(), gen);
10541054

10551055
#pragma omp parallel for
10561056
for (size_t i = 0; i < nrow; i++) {
1057-
size_t base_idx = i * node_degree + seg_idx * segment_size;
1058-
auto h_neighbor_list = h_graph + base_idx;
1059-
auto h_dist_list = h_dists.data_handle() + base_idx;
1057+
size_t base_idx = i * node_degree + seg_idx * segment_size;
1058+
auto h_neighbor_list = h_graph + base_idx;
1059+
auto h_dist_list = h_dists.data_handle() + base_idx;
1060+
size_t idx = base_idx;
1061+
size_t self_in_this_seg = 0;
10601062
for (size_t j = 0; j < static_cast<size_t>(segment_size); j++) {
1061-
size_t idx = base_idx + j;
10621063
Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
10631064
if ((size_t)id == i) {
1064-
id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx;
1065+
idx++;
1066+
id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
1067+
self_in_this_seg = 1;
10651068
}
1066-
h_neighbor_list[j].id_with_flag() = id;
1067-
h_dist_list[j] = std::numeric_limits<DistData_t>::max();
1069+
1070+
h_neighbor_list[j].id_with_flag() =
1071+
j < (rand_seq.size() - self_in_this_seg) && size_t(id) < nrow
1072+
? id
1073+
: std::numeric_limits<Index_t>::max();
1074+
h_dist_list[j] = std::numeric_limits<DistData_t>::max();
1075+
idx++;
10681076
}
10691077
}
10701078
}

cpp/tests/neighbors/ann_cagra.cuh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,12 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam<AnnCagraInputs> {
952952
(ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows))
953953
GTEST_SKIP();
954954

955+
// Avoid splitting datasets with a size of 0
956+
if (ps.n_rows <= 3) GTEST_SKIP();
957+
958+
// IVF_PQ requires `n_rows >= n_lists`.
959+
if (ps.n_rows < 8 && ps.build_algo == graph_build_algo::IVF_PQ) GTEST_SKIP();
960+
955961
size_t queries_size = ps.n_queries * ps.k;
956962
std::vector<IdxT> indices_Cagra(queries_size);
957963
std::vector<IdxT> indices_naive(queries_size);
@@ -1161,6 +1167,24 @@ inline std::vector<AnnCagraInputs> generate_inputs()
11611167
{0.995});
11621168
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
11631169

1170+
// Corner cases for small datasets
1171+
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
1172+
{2},
1173+
{3, 5, 31, 32, 64, 101},
1174+
{1, 10},
1175+
{2}, // k
1176+
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
1177+
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
1178+
{0}, // query size
1179+
{0},
1180+
{256},
1181+
{1},
1182+
{cuvs::distance::DistanceType::L2Expanded},
1183+
{false},
1184+
{true},
1185+
{0.995});
1186+
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
1187+
11641188
// Varying dim and build algo.
11651189
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
11661190
{100},

java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.nvidia.cuvs;
1818

1919
import java.util.Arrays;
20+
import java.util.BitSet;
2021
import java.util.List;
2122

2223
/**
@@ -28,7 +29,8 @@ public class BruteForceQuery {
2829

2930
private List<Integer> mapping;
3031
private float[][] queryVectors;
31-
private long[] prefilter;
32+
private BitSet[] prefilters;
33+
private int numDocs = -1;
3234
private int topK;
3335

3436
/**
@@ -40,12 +42,15 @@ public class BruteForceQuery {
4042
* @param topK the top k results to return
4143
* @param prefilters the prefilter data to use while searching the BRUTEFORCE
4244
* index
45+
* @param numDocs Maximum number of bits in each prefilter, representing the number of documents in this index.
46+
* Used only when prefilter(s) is/are passed.
4347
*/
44-
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, long[] prefilter) {
48+
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
4549
this.queryVectors = queryVectors;
4650
this.mapping = mapping;
4751
this.topK = topK;
48-
this.prefilter = prefilter;
52+
this.prefilters = prefilters;
53+
this.numDocs = numDocs;
4954
}
5055

5156
/**
@@ -78,16 +83,25 @@ public int getTopK() {
7883
/**
7984
* Gets the prefilters as an array of bitsets
8085
*
81-
* @return a long array
86+
* @return an array of bitsets
8287
*/
83-
public long[] getPrefilter() {
84-
return prefilter;
88+
public BitSet[] getPrefilters() {
89+
return prefilters;
90+
}
91+
92+
/**
93+
* Gets the number of documents supposed to be in this index, as used for prefilters
94+
*
95+
* @return number of documents as an integer
96+
*/
97+
public int getNumDocs() {
98+
return numDocs;
8599
}
86100

87101
@Override
88102
public String toString() {
89103
return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", prefilter="
90-
+ Arrays.toString(prefilter) + ", topK=" + topK + "]";
104+
+ Arrays.toString(prefilters) + ", topK=" + topK + "]";
91105
}
92106

93107
/**
@@ -96,7 +110,8 @@ public String toString() {
96110
public static class Builder {
97111

98112
private float[][] queryVectors;
99-
private long[] prefilter;
113+
private BitSet[] prefilters;
114+
private int numDocs;
100115
private List<Integer> mapping;
101116
private int topK = 2;
102117

@@ -134,13 +149,15 @@ public Builder withTopK(int topK) {
134149
}
135150

136151
/**
137-
* Sets the prefilter data for building the {@link BruteForceQuery}.
152+
* Sets the prefilters data for building the {@link BruteForceQuery}.
138153
*
139-
* @param prefilter a one-dimensional long array
154+
* @param prefilters array of bitsets, as many as queries, each containing as
155+
* many bits as there are vectors in the index
140156
* @return an instance of this Builder
141157
*/
142-
public Builder withPrefilter(long[] prefilter) {
143-
this.prefilter = prefilter;
158+
public Builder withPrefilter(BitSet[] prefilters, int numDocs) {
159+
this.prefilters = prefilters;
160+
this.numDocs = numDocs;
144161
return this;
145162
}
146163

@@ -150,7 +167,7 @@ public Builder withPrefilter(long[] prefilter) {
150167
* @return an instance of {@link BruteForceQuery}
151168
*/
152169
public BruteForceQuery build() {
153-
return new BruteForceQuery(queryVectors, mapping, topK, prefilter);
170+
return new BruteForceQuery(queryVectors, mapping, topK, prefilters, numDocs);
154171
}
155172
}
156173
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626
import java.lang.foreign.MemorySegment;
2727
import java.lang.foreign.SequenceLayout;
2828
import java.lang.invoke.MethodHandle;
29+
import java.nio.ByteBuffer;
30+
import java.nio.ByteOrder;
2931
import java.nio.file.Files;
3032
import java.nio.file.Path;
33+
import java.util.Arrays;
34+
import java.util.BitSet;
3135
import java.util.Objects;
3236
import java.util.UUID;
3337

@@ -59,7 +63,7 @@ public class BruteForceIndexImpl implements BruteForceIndex{
5963
FunctionDescriptor.of(ADDRESS, ADDRESS, C_LONG, C_LONG, ADDRESS, ADDRESS, C_INT));
6064

6165
private static final MethodHandle searchMethodHandle = downcallHandle("search_brute_force_index",
62-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG, C_LONG));
66+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG));
6367

6468
private static final MethodHandle destroyIndexMethodHandle = downcallHandle("destroy_brute_force_index",
6569
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS));
@@ -169,16 +173,24 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
169173
long numQueries = cuvsQuery.getQueryVectors().length;
170174
long numBlocks = cuvsQuery.getTopK() * numQueries;
171175
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
172-
long prefilterDataLength = cuvsQuery.getPrefilter() != null ? cuvsQuery.getPrefilter().length : 0;
173176
long numRows = dataset != null ? dataset.length : 0;
174177

175178
SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
176179
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
177180
MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout);
178181
MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout);
179-
MemorySegment prefilterDataMemorySegment = cuvsQuery.getPrefilter() != null
180-
? Util.buildMemorySegment(resources.getArena(), cuvsQuery.getPrefilter())
181-
: MemorySegment.NULL;
182+
183+
// prepare the prefiltering data
184+
long prefilterDataLength = 0;
185+
MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
186+
BitSet[] prefilters = cuvsQuery.getPrefilters();
187+
if (prefilters != null && prefilters.length > 0) {
188+
BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
189+
long filters[] = concatenatedFilters.toLongArray();
190+
prefilterDataMemorySegment = Util.buildMemorySegment(resources.getArena(), filters);
191+
prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
192+
}
193+
182194
MemorySegment querySeg = Util.buildMemorySegment(resources.getArena(), cuvsQuery.getQueryVectors());
183195
try (var localArena = Arena.ofConfined()) {
184196
MemorySegment returnValue = localArena.allocate(C_INT);
@@ -193,7 +205,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
193205
distancesMemorySegment,
194206
returnValue,
195207
prefilterDataMemorySegment,
196-
prefilterDataLength, numRows
208+
prefilterDataLength
197209
);
198210
checkError(returnValue.get(C_INT, 0L), "searchMethodHandle");
199211
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.lang.invoke.MethodHandle;
2626
import java.lang.invoke.VarHandle;
2727
import java.util.ArrayList;
28+
import java.util.Arrays;
29+
import java.util.BitSet;
2830
import java.util.List;
2931

3032
import com.nvidia.cuvs.GPUInfo;
@@ -184,6 +186,14 @@ public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
184186
return dataMemorySegment;
185187
}
186188

189+
public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
190+
int cells = data.length;
191+
MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, C_CHAR);
192+
MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
193+
MemorySegment.copy(data, 0, dataMemorySegment, C_CHAR, 0, cells);
194+
return dataMemorySegment;
195+
}
196+
187197
/**
188198
* A utility method for building a {@link MemorySegment} for a 2D float array.
189199
*
@@ -201,4 +211,20 @@ public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
201211
}
202212
return dataMemorySegment;
203213
}
214+
215+
public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
216+
BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
217+
for (int i = 0; i < arr.length; i++) {
218+
BitSet b = arr[i];
219+
if (b == null || b.length() == 0) {
220+
ret.set(i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet);
221+
} else {
222+
for (int j = 0; j < maxSizeOfEachBitSet; j++) {
223+
ret.set(i * maxSizeOfEachBitSet + j, b.get(j));
224+
}
225+
}
226+
}
227+
return ret;
228+
}
229+
204230
}

0 commit comments

Comments
 (0)