Skip to content

Commit 4322b2f

Browse files
committed
VPQ distance: don't pass n_subspace as parameter, because it can be cheaply computed from dim and PQ_LEN
1 parent 6fac19b commit 4322b2f

File tree

2 files changed

+6
-26
lines changed

2 files changed

+6
-26
lines changed

cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
6060
// const CODE_BOOK_T* vq_code_book_ptr;
6161
// const CODE_BOOK_T* pq_code_book_ptr;
6262
// std::uint32_t encoded_dataset_dim;
63-
// std::uint32_t n_subspace;
6463

6564
RAFT_INLINE_FUNCTION static constexpr auto encoded_dataset_ptr(args_t& args) noexcept
6665
-> const uint8_t*&
@@ -80,10 +79,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
8079
{
8180
return args.extra_word1;
8281
}
83-
RAFT_INLINE_FUNCTION static constexpr auto n_subspace(args_t& args) noexcept -> uint32_t&
84-
{
85-
return args.extra_word2;
86-
}
8782

8883
RAFT_INLINE_FUNCTION static constexpr auto encoded_dataset_ptr(const args_t& args) noexcept
8984
-> const uint8_t* const&
@@ -104,11 +99,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
10499
{
105100
return args.extra_word1;
106101
}
107-
RAFT_INLINE_FUNCTION static constexpr auto n_subspace(const args_t& args) noexcept
108-
-> const uint32_t&
109-
{
110-
return args.extra_word2;
111-
}
112102

113103
static constexpr std::uint32_t kSMemCodeBookSizeInBytes =
114104
(1 << PQ_BITS) * PQ_LEN * utils::size_of<CODE_BOOK_T>();
@@ -117,7 +107,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
117107
compute_distance_type* compute_distance_impl,
118108
const std::uint8_t* encoded_dataset_ptr,
119109
std::uint32_t encoded_dataset_dim,
120-
std::uint32_t n_subspace,
121110
const CODE_BOOK_T* vq_code_book_ptr,
122111
const CODE_BOOK_T* pq_code_book_ptr,
123112
IndexT size,
@@ -133,7 +122,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
133122
cagra_q_dataset_descriptor_t::vq_code_book_ptr(args) = vq_code_book_ptr;
134123
this->pq_code_book_ptr() = pq_code_book_ptr;
135124
cagra_q_dataset_descriptor_t::encoded_dataset_dim(args) = encoded_dataset_dim;
136-
cagra_q_dataset_descriptor_t::n_subspace(args) = n_subspace;
137125
static_assert(sizeof(*this) == sizeof(base_type));
138126
static_assert(alignof(cagra_q_dataset_descriptor_t) == alignof(base_type));
139127
}
@@ -241,8 +229,7 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker(
241229
const uint8_t* __restrict__ dataset_ptr,
242230
const typename DescriptorT::CODE_BOOK_T* __restrict__ vq_code_book_ptr,
243231
uint32_t dim,
244-
uint32_t pq_codebook_ptr,
245-
uint32_t n_subspace) -> typename DescriptorT::DISTANCE_T
232+
uint32_t pq_codebook_ptr) -> typename DescriptorT::DISTANCE_T
246233
{
247234
using DISTANCE_T = typename DescriptorT::DISTANCE_T;
248235
using LOAD_T = typename DescriptorT::LOAD_T;
@@ -262,8 +249,9 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker(
262249
constexpr auto kTeamMask = DescriptorT::kTeamSize - 1;
263250
constexpr auto kTeamVLen = TeamSize * vlen;
264251

265-
const auto laneId = threadIdx.x & kTeamMask;
266-
DISTANCE_T norm = 0;
252+
const auto n_subspace = raft::div_rounding_up_unsafe(dim, PQ_LEN);
253+
const auto laneId = threadIdx.x & kTeamMask;
254+
DISTANCE_T norm = 0;
267255
for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim;
268256
elem_offset += DatasetBlockDim / PQ_LEN) {
269257
// Loading PQ codes
@@ -369,11 +357,10 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
369357
uint32_t vq_code;
370358
device::ldg_cg(vq_code, reinterpret_cast<const std::uint32_t*>(dataset_ptr));
371359
return compute_distance_vpq_worker<DescriptorT>(
372-
dataset_ptr,
360+
dataset_ptr + 4 /* advance dataset pointer by the size of vq_code */,
373361
DescriptorT::vq_code_book_ptr(args) + args.dim * vq_code,
374362
args.dim,
375-
args.smem_ws_ptr,
376-
DescriptorT::n_subspace(args));
363+
args.smem_ws_ptr);
377364
}
378365

379366
template <cuvs::distance::DistanceType Metric,
@@ -389,7 +376,6 @@ RAFT_KERNEL __launch_bounds__(1, 1)
389376
vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
390377
const std::uint8_t* encoded_dataset_ptr,
391378
uint32_t encoded_dataset_dim,
392-
uint32_t n_subspace,
393379
const CodebookT* vq_code_book_ptr,
394380
const CodebookT* pq_code_book_ptr,
395381
IndexT size,
@@ -410,7 +396,6 @@ RAFT_KERNEL __launch_bounds__(1, 1)
410396
reinterpret_cast<typename base_type::compute_distance_type*>(&compute_distance_vpq<desc_type>),
411397
encoded_dataset_ptr,
412398
encoded_dataset_dim,
413-
n_subspace,
414399
vq_code_book_ptr,
415400
pq_code_book_ptr,
416401
size,
@@ -438,7 +423,6 @@ vpq_descriptor_spec<Metric,
438423
DistanceT>::init_(const cagra::search_params& params,
439424
const std::uint8_t* encoded_dataset_ptr,
440425
uint32_t encoded_dataset_dim,
441-
uint32_t n_subspace,
442426
const CodebookT* vq_code_book_ptr,
443427
const CodebookT* pq_code_book_ptr,
444428
IndexT size,
@@ -460,7 +444,6 @@ vpq_descriptor_spec<Metric,
460444
nullptr,
461445
encoded_dataset_ptr,
462446
encoded_dataset_dim,
463-
n_subspace,
464447
vq_code_book_ptr,
465448
pq_code_book_ptr,
466449
size,
@@ -477,7 +460,6 @@ vpq_descriptor_spec<Metric,
477460
DistanceT><<<1, 1, 0, stream>>>(result.dev_ptr,
478461
encoded_dataset_ptr,
479462
encoded_dataset_dim,
480-
n_subspace,
481463
vq_code_book_ptr,
482464
pq_code_book_ptr,
483465
size,

cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
6363
return init_(params,
6464
dataset.data.data_handle(),
6565
dataset.encoded_row_length(),
66-
dataset.pq_dim(),
6766
dataset.vq_code_book.data_handle(),
6867
dataset.pq_code_book.data_handle(),
6968
IndexT(dataset.n_rows()),
@@ -91,7 +90,6 @@ struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
9190
const cagra::search_params& params,
9291
const std::uint8_t* encoded_dataset_ptr,
9392
uint32_t encoded_dataset_dim,
94-
uint32_t n_subspace,
9593
const CodebookT* vq_code_book_ptr,
9694
const CodebookT* pq_code_book_ptr,
9795
IndexT size,

0 commit comments

Comments
 (0)