Skip to content

Commit 028e98b

Browse files
committed
instantiate ivfflat_interleaved_scan in separate files
1 parent daf2181 commit 028e98b

11 files changed

+325
-74
lines changed

cpp/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ if(BUILD_SHARED_LIBS)
432432
src/neighbors/ivf_flat/ivf_flat_search_half_int64_t.cu
433433
src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu
434434
src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu
435+
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t.cu
436+
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t.cu
437+
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu
438+
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu
435439
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
436440
src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
437441
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu

cpp/src/neighbors/ivf_common.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -148,11 +148,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
148148
return ix_min;
149149
}
150150

151-
template <int BlockDim, typename IdxT>
151+
template <int BlockDim, typename IdxT, typename DbIdxT>
152152
__launch_bounds__(BlockDim) RAFT_KERNEL
153153
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
154154
const uint32_t* neighbors_in, // [n_queries, topk]
155-
const IdxT* const* db_indices, // [n_clusters][..]
155+
const DbIdxT* const* db_indices, // [n_clusters][..]
156156
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
157157
const uint32_t* chunk_indices, // [n_queries, n_probes]
158158
uint32_t n_queries,
@@ -181,10 +181,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
181181
* probed clusters / defined by the `chunk_indices`.
182182
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
183183
*/
184-
template <typename IdxT>
184+
template <typename IdxT, typename DbIdxT>
185185
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
186186
const uint32_t* neighbors_in, // [n_queries, topk]
187-
const IdxT* const* db_indices, // [n_clusters][..]
187+
const DbIdxT* const* db_indices, // [n_clusters][..]
188188
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
189189
const uint32_t* chunk_indices, // [n_queries, n_probes]
190190
uint32_t n_queries,

cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -100,10 +100,11 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
100100
* there are no dependencies between threads, hence no constraints on the block size.
101101
*
102102
* @tparam T element type.
103-
* @tparam IdxT type of the indices in the source source_vecs
103+
* @tparam IdxT type of the vector ids in the index (corresponds to the second arg of index<T, IdxT>)
104104
* @tparam LabelT label type
105105
* @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise
106106
* we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1.
107+
* @tparam SourceIdxT type of the input indices in source_ixs (usually the same as IdxT)
107108
*
108109
* @param[in] labels device pointer to the cluster ids for each row [n_rows]
109110
* @param[in] source_vecs device pointer to the input data [n_rows, dim]
@@ -118,10 +119,10 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
118119
* @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`.
119120
*
120121
*/
121-
template <typename T, typename IdxT, typename LabelT, bool gather_src = false>
122+
template <typename T, typename IdxT, typename LabelT, bool gather_src = false, typename SourceIdxT>
122123
RAFT_KERNEL build_index_kernel(const LabelT* labels,
123124
const T* source_vecs,
124-
const IdxT* source_ixs,
125+
const SourceIdxT* source_ixs,
125126
T** list_data_ptrs,
126127
IdxT** list_index_ptrs,
127128
uint32_t* list_sizes_ptr,
@@ -135,7 +136,10 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
135136
auto source_ix = source_ixs == nullptr ? i + batch_offset : source_ixs[i];
136137
// In the context of refinement, some indices may be invalid (the generating NN algorithm does
137138
// not return enough valid items). Do not add the item to the index in this case.
138-
if (source_ix == ivf::kInvalidRecord<IdxT> || source_ix == raft::upper_bound<IdxT>()) { return; }
139+
if (source_ix == ivf::kInvalidRecord<SourceIdxT> ||
140+
source_ix == raft::upper_bound<SourceIdxT>()) {
141+
return;
142+
}
139143

140144
auto list_id = labels[i];
141145
auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1);
@@ -460,11 +464,11 @@ inline auto build(raft::resources const& handle,
460464
* @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates]
461465
* @param[in] n_candidates number of neighbor candidates per query (second dimension of candidate_idx)
462466
*/
463-
template <typename T, typename IdxT>
467+
template <typename T, typename IdxT, typename CandidateIdxT>
464468
inline void fill_refinement_index(raft::resources const& handle,
465469
index<T, IdxT>* refinement_index,
466470
const T* dataset,
467-
const IdxT* candidate_idx,
471+
const CandidateIdxT* candidate_idx,
468472
IdxT n_queries,
469473
uint32_t n_candidates)
470474
{

cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -40,11 +40,6 @@ using namespace cuvs::spatial::knn::detail; // NOLINT
4040

4141
constexpr int kThreadsPerBlock = 128;
4242

43-
auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool
44-
{
45-
return k <= raft::matrix::detail::select::warpsort::kMaxCapacity;
46-
}
47-
4843
/**
4944
* @brief Copy `n` elements per block from one place to another.
5045
*

cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
#include <cstdint>
2020
#include <cuda_fp16.h>
2121

22+
#include "../detail/ann_utils.cuh"
2223
#include <cuvs/neighbors/common.hpp>
2324
#include <raft/core/resource/cuda_stream.hpp>
2425
#include <raft/core/resources.hpp>
26+
#include <raft/util/raft_explicit.hpp>
2527

2628
namespace cuvs::neighbors::ivf_flat::detail {
2729
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
@@ -40,45 +42,55 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
4042
uint32_t* neighbors,
4143
float* distances,
4244
uint32_t& grid_dim_x,
43-
rmm::cuda_stream_view stream);
45+
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
4446

45-
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, AccT, SampleFilterT) \
46-
extern template void ivfflat_interleaved_scan<T, AccT, IdxT, SampleFilterT>( \
47-
const index<T, IdxT>& index, \
48-
const T* queries, \
49-
const uint32_t* coarse_query_results, \
50-
const uint32_t n_queries, \
51-
const uint32_t queries_offset, \
52-
const cuvs::distance::DistanceType metric, \
53-
const uint32_t n_probes, \
54-
const uint32_t k, \
55-
const uint32_t max_samples, \
56-
const uint32_t* chunk_indices, \
57-
const bool select_min, \
58-
SampleFilterT sample_filter, \
59-
uint32_t* neighbors, \
60-
float* distances, \
61-
uint32_t& grid_dim_x, \
62-
rmm::cuda_stream_view stream);
47+
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
48+
extern template void \
49+
ivfflat_interleaved_scan<T, \
50+
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
51+
IdxT, \
52+
SampleFilterT>(const index<T, IdxT>& index, \
53+
const T* queries, \
54+
const uint32_t* coarse_query_results, \
55+
const uint32_t n_queries, \
56+
const uint32_t queries_offset, \
57+
const cuvs::distance::DistanceType metric, \
58+
const uint32_t n_probes, \
59+
const uint32_t k, \
60+
const uint32_t max_samples, \
61+
const uint32_t* chunk_indices, \
62+
const bool select_min, \
63+
SampleFilterT sample_filter, \
64+
uint32_t* neighbors, \
65+
float* distances, \
66+
uint32_t& grid_dim_x, \
67+
rmm::cuda_stream_view stream);
6368

64-
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float,
65-
int64_t,
66-
float,
67-
cuvs::neighbors::filtering::none_sample_filter);
68-
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half,
69-
int64_t,
70-
float,
71-
cuvs::neighbors::filtering::none_sample_filter);
69+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, cuvs::neighbors::filtering::none_sample_filter);
70+
71+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, cuvs::neighbors::filtering::none_sample_filter);
7272

7373
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t,
7474
int64_t,
75-
float,
7675
cuvs::neighbors::filtering::none_sample_filter);
76+
7777
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t,
7878
int64_t,
79-
float,
8079
cuvs::neighbors::filtering::none_sample_filter);
8180

81+
#define COMMA ,
82+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
83+
float, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
84+
85+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
86+
half, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
87+
88+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
89+
int8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
90+
91+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
92+
uint8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
93+
#undef COMMA
8294
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN
8395

8496
} // namespace cuvs::neighbors::ivf_flat::detail
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "../detail/ann_utils.cuh"
18+
#include "ivf_flat_interleaved_scan.cuh"
19+
#include <cstdint>
20+
#include <cuda_fp16.h>
21+
#include <cuvs/neighbors/common.hpp>
22+
#include <cuvs/neighbors/ivf_flat.hpp>
23+
#include <raft/core/resource/cuda_stream.hpp>
24+
#include <raft/core/resources.hpp>
25+
26+
namespace cuvs::neighbors::ivf_flat::detail {
27+
28+
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
29+
template void \
30+
ivfflat_interleaved_scan<T, \
31+
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
32+
IdxT, \
33+
SampleFilterT>(const index<T, IdxT>& index, \
34+
const T* queries, \
35+
const uint32_t* coarse_query_results, \
36+
const uint32_t n_queries, \
37+
const uint32_t queries_offset, \
38+
const cuvs::distance::DistanceType metric, \
39+
const uint32_t n_probes, \
40+
const uint32_t k, \
41+
const uint32_t max_samples, \
42+
const uint32_t* chunk_indices, \
43+
const bool select_min, \
44+
SampleFilterT sample_filter, \
45+
uint32_t* neighbors, \
46+
float* distances, \
47+
uint32_t& grid_dim_x, \
48+
rmm::cuda_stream_view stream);
49+
#define COMMA ,
50+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, filtering::none_sample_filter);
51+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float,
52+
int64_t,
53+
filtering::bitset_filter<uint32_t COMMA int64_t>);
54+
#undef COMMA
55+
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN
56+
57+
} // namespace cuvs::neighbors::ivf_flat::detail
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "../detail/ann_utils.cuh"
18+
#include "ivf_flat_interleaved_scan.cuh"
19+
#include <cstdint>
20+
#include <cuda_fp16.h>
21+
#include <cuvs/neighbors/common.hpp>
22+
#include <cuvs/neighbors/ivf_flat.hpp>
23+
#include <raft/core/resource/cuda_stream.hpp>
24+
#include <raft/core/resources.hpp>
25+
26+
namespace cuvs::neighbors::ivf_flat::detail {
27+
28+
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
29+
template void \
30+
ivfflat_interleaved_scan<T, \
31+
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
32+
IdxT, \
33+
SampleFilterT>(const index<T, IdxT>& index, \
34+
const T* queries, \
35+
const uint32_t* coarse_query_results, \
36+
const uint32_t n_queries, \
37+
const uint32_t queries_offset, \
38+
const cuvs::distance::DistanceType metric, \
39+
const uint32_t n_probes, \
40+
const uint32_t k, \
41+
const uint32_t max_samples, \
42+
const uint32_t* chunk_indices, \
43+
const bool select_min, \
44+
SampleFilterT sample_filter, \
45+
uint32_t* neighbors, \
46+
float* distances, \
47+
uint32_t& grid_dim_x, \
48+
rmm::cuda_stream_view stream);
49+
#define COMMA ,
50+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, filtering::none_sample_filter);
51+
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half,
52+
int64_t,
53+
filtering::bitset_filter<uint32_t COMMA int64_t>);
54+
#undef COMMA
55+
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN
56+
57+
} // namespace cuvs::neighbors::ivf_flat::detail

0 commit comments

Comments
 (0)