Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,14 @@ if(BUILD_SHARED_LIBS)
src/neighbors/ivf_flat/ivf_flat_search_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t_bitset.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t_bitset.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t_bitset.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t_bitset.cu
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -148,17 +148,19 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
return ix_min;
}

template <int BlockDim, typename IdxT>
template <int BlockDim, typename IdxT, typename DbIdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const DbIdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
uint32_t n_probes,
uint32_t topk)
{
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");
const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x);
const uint32_t query_ix = i / uint64_t(topk);
if (query_ix >= n_queries) { return; }
Expand All @@ -170,8 +172,8 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
uint32_t data_ix = neighbors_in[k];
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT>;
neighbors_out[k] = valid ? static_cast<IdxT>(db_indices[clusters_to_probe[chunk_ix]][data_ix])
: kOutOfBoundsRecord<IdxT>;
}

/**
Expand All @@ -181,10 +183,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT>
template <typename IdxT, typename DbIdxT>
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const DbIdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand Down
20 changes: 12 additions & 8 deletions cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -100,8 +100,9 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
* there are no dependencies between threads, hence no constraints on the block size.
*
* @tparam T element type.
* @tparam IdxT type of the indices in the source source_vecs
* @tparam IdxT type of the vector ids in the index (corresponds to second arg of index<T, IdxT>)
* @tparam LabelT label type
* @tparam SourceIdxT input index type (usually same as IdxT)
* @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise
* we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1.
*
Expand All @@ -118,10 +119,10 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
* @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`.
*
*/
template <typename T, typename IdxT, typename LabelT, bool gather_src = false>
template <typename T, typename IdxT, typename LabelT, typename SourceIdxT, bool gather_src = false>
RAFT_KERNEL build_index_kernel(const LabelT* labels,
const T* source_vecs,
const IdxT* source_ixs,
const SourceIdxT* source_ixs,
T** list_data_ptrs,
IdxT** list_index_ptrs,
uint32_t* list_sizes_ptr,
Expand All @@ -135,7 +136,10 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
auto source_ix = source_ixs == nullptr ? i + batch_offset : source_ixs[i];
// In the context of refinement, some indices may be invalid (the generating NN algorithm does
// not return enough valid items). Do not add the item to the index in this case.
if (source_ix == ivf::kInvalidRecord<IdxT> || source_ix == raft::upper_bound<IdxT>()) { return; }
if (source_ix == ivf::kInvalidRecord<SourceIdxT> ||
source_ix == raft::upper_bound<SourceIdxT>()) {
return;
}

auto list_id = labels[i];
auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1);
Expand Down Expand Up @@ -460,11 +464,11 @@ inline auto build(raft::resources const& handle,
* @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates]
* @param[in] n_candidates number of neighbor candidates per query (second dimension of candidate_idx)
*/
template <typename T, typename IdxT>
template <typename T, typename IdxT, typename CandidateIdxT>
inline void fill_refinement_index(raft::resources const& handle,
index<T, IdxT>* refinement_index,
const T* dataset,
const IdxT* candidate_idx,
const CandidateIdxT* candidate_idx,
IdxT n_queries,
uint32_t n_candidates)
{
Expand Down Expand Up @@ -500,7 +504,7 @@ inline void fill_refinement_index(raft::resources const& handle,

const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(n_queries * n_candidates, block_dim.x));
build_index_kernel<T, IdxT, LabelT, true>
build_index_kernel<T, IdxT, LabelT, CandidateIdxT, true>
<<<grid_dim, block_dim, 0, stream>>>(new_labels.data(),
dataset,
candidate_idx,
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,11 +40,6 @@ using namespace cuvs::spatial::knn::detail; // NOLINT

constexpr int kThreadsPerBlock = 128;

/**
 * @brief Check whether a requested `k` fits within the warpsort capacity limit.
 *
 * @param k number of neighbors requested
 * @return true when `k` does not exceed
 *         `raft::matrix::detail::select::warpsort::kMaxCapacity`, false otherwise.
 */
auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool
{
  // Keep the constant's own type so the comparison performs the same conversions as a direct use.
  const auto capacity_limit = raft::matrix::detail::select::warpsort::kMaxCapacity;
  return !(k > capacity_limit);
}

/**
* @brief Copy `n` elements per block from one place to another.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "../detail/ann_utils.cuh"
#include "ivf_flat_interleaved_scan.cuh"
#include <cstdint>
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/ivf_flat.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

/*
 * Emits an explicit instantiation definition (`template void ...`) of
 * `ivfflat_interleaved_scan` for one (T, IdxT, SampleFilterT) combination.
 *
 *  - T             element type of the index and of the queries
 *  - IdxT          second template argument of `index<T, IdxT>`
 *  - SampleFilterT type of the `sample_filter` argument
 *
 * The accumulator template argument is not a macro parameter: it is always
 * `cuvs::spatial::knn::detail::utils::config<T>::value_t`.
 *
 * `ivfflat_interleaved_scan` and `index` are spelled unqualified, so the macro has to be
 * expanded in a scope where both names resolve (the .cu files using it expand it inside
 * `cuvs::neighbors::ivf_flat::detail`). The expansion already ends with a semicolon.
 *
 * A SampleFilterT that contains a top-level comma must be passed with the COMMA helper macro,
 * otherwise the preprocessor would split it into several macro arguments.
 */
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
template void \
ivfflat_interleaved_scan<T, \
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
IdxT, \
SampleFilterT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const cuvs::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const uint32_t max_samples, \
const uint32_t* chunk_indices, \
const bool select_min, \
SampleFilterT sample_filter, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream);

/*
 * Expands to a single comma. Lets a template-id with several arguments (for example
 * `filtering::bitset_filter<uint32_t COMMA int64_t>`) be passed as ONE macro argument to
 * CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN; a literal comma would split the argument.
 * NOTE(review): no matching #undef is visible in this header, so COMMA stays defined in every
 * translation unit that includes it - confirm this is intended.
 */
#define COMMA ,
96 changes: 96 additions & 0 deletions cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint>
#include <cuda_fp16.h>

#include "../detail/ann_utils.cuh"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/raft_explicit.hpp>

namespace cuvs::neighbors::ivf_flat::detail {
/**
 * @brief Declaration-only view of the IVF-Flat interleaved scan entry point.
 *
 * This header gives the signature only; the position of the function body is taken by the
 * RAFT_EXPLICIT macro from <raft/util/raft_explicit.hpp>.
 * NOTE(review): presumably RAFT_EXPLICIT prevents implicit instantiation in translation units
 * that include this header, so that only the `extern template` instantiations listed below are
 * used - confirm against raft_explicit.hpp.
 *
 * @tparam T                element type of the index and the queries
 * @tparam AccT             second template argument; the instantiations below always pass
 *                          `utils::config<T>::value_t` for it
 * @tparam IdxT             second template argument of `index<T, IdxT>`
 * @tparam IvfSampleFilterT type of `sample_filter`
 *
 * NOTE(review): parameter semantics are not visible in this header. Judging by the types,
 * `neighbors` and `distances` are writable output buffers and `grid_dim_x` (non-const
 * reference) can be modified by the callee - confirm against the definition.
 * NOTE(review): `index` and `rmm::cuda_stream_view` have no dedicated include in this header;
 * presumably they arrive transitively - confirm.
 */
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const cuvs::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const uint32_t max_samples,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;

/*
 * Emits an explicit instantiation DECLARATION (`extern template`) for one
 * (T, IdxT, SampleFilterT) combination, with AccT fixed to `utils::config<T>::value_t`.
 * An `extern template` declaration tells the compiler not to instantiate the specialization in
 * the including translation unit; a matching explicit instantiation definition must exist in
 * some other translation unit. The macro is local to this header and is #undef'd below.
 */
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
extern template void \
ivfflat_interleaved_scan<T, \
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
IdxT, \
SampleFilterT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const cuvs::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const uint32_t max_samples, \
const uint32_t* chunk_indices, \
const bool select_min, \
SampleFilterT sample_filter, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream);

// Declared combinations without a bitset filter: T in {float, half, int8_t, uint8_t},
// IdxT = int64_t, filter type = none_sample_filter.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t,
int64_t,
cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t,
int64_t,
cuvs::neighbors::filtering::none_sample_filter);

// COMMA expands to ',' so that `bitset_filter<uint32_t, int64_t>` can be passed as a single
// macro argument; a literal comma would be treated as a macro-argument separator.
#define COMMA ,
// Same four element types, filter type = bitset_filter<uint32_t, int64_t>.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
float, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
half, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
int8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
uint8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
// Both helper macros are removed again so they do not leak out of this header.
#undef COMMA
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN

} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ivf_flat_interleaved_scan_explicit_inst.cuh"

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiation for T = float, IdxT = int64_t, filter type = none_sample_filter.
// The macro is expected to come from the header included above. `filtering::` is looked up
// from within cuvs::neighbors::ivf_flat::detail, i.e. it names cuvs::neighbors::filtering.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, filtering::none_sample_filter);
} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ivf_flat_interleaved_scan_explicit_inst.cuh"

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiation for T = float, IdxT = int64_t,
// filter type = bitset_filter<uint32_t, int64_t>.
// COMMA stands in for ',' so the template-id is passed as a single macro argument.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);
} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"
#include <cuda_fp16.h>

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiation for T = half, IdxT = int64_t, filter type = none_sample_filter.
// The macro is expected to come from the explicit-instantiation header included above.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, filtering::none_sample_filter);

} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"
#include <cuda_fp16.h>

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiation for T = half, IdxT = int64_t,
// filter type = bitset_filter<uint32_t, int64_t>.
// COMMA stands in for ',' so the template-id is passed as a single macro argument.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);

} // namespace cuvs::neighbors::ivf_flat::detail
Loading