@@ -60,7 +60,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
6060 // const CODE_BOOK_T* vq_code_book_ptr;
6161 // const CODE_BOOK_T* pq_code_book_ptr;
6262 // std::uint32_t encoded_dataset_dim;
63- // std::uint32_t n_subspace;
6463
6564 RAFT_INLINE_FUNCTION static constexpr auto encoded_dataset_ptr (args_t & args) noexcept
6665 -> const uint8_t*&
@@ -80,10 +79,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
8079 {
8180 return args.extra_word1 ;
8281 }
83- RAFT_INLINE_FUNCTION static constexpr auto n_subspace (args_t & args) noexcept -> uint32_t&
84- {
85- return args.extra_word2 ;
86- }
8782
8883 RAFT_INLINE_FUNCTION static constexpr auto encoded_dataset_ptr (const args_t & args) noexcept
8984 -> const uint8_t* const &
@@ -104,11 +99,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
10499 {
105100 return args.extra_word1 ;
106101 }
107- RAFT_INLINE_FUNCTION static constexpr auto n_subspace (const args_t & args) noexcept
108- -> const uint32_t&
109- {
110- return args.extra_word2 ;
111- }
112102
113103 static constexpr std::uint32_t kSMemCodeBookSizeInBytes =
114104 (1 << PQ_BITS) * PQ_LEN * utils::size_of<CODE_BOOK_T>();
@@ -117,7 +107,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
117107 compute_distance_type* compute_distance_impl,
118108 const std::uint8_t * encoded_dataset_ptr,
119109 std::uint32_t encoded_dataset_dim,
120- std::uint32_t n_subspace,
121110 const CODE_BOOK_T* vq_code_book_ptr,
122111 const CODE_BOOK_T* pq_code_book_ptr,
123112 IndexT size,
@@ -133,7 +122,6 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
133122 cagra_q_dataset_descriptor_t::vq_code_book_ptr (args) = vq_code_book_ptr;
134123 this ->pq_code_book_ptr () = pq_code_book_ptr;
135124 cagra_q_dataset_descriptor_t::encoded_dataset_dim (args) = encoded_dataset_dim;
136- cagra_q_dataset_descriptor_t::n_subspace (args) = n_subspace;
137125 static_assert (sizeof (*this ) == sizeof (base_type));
138126 static_assert (alignof (cagra_q_dataset_descriptor_t ) == alignof (base_type));
139127 }
@@ -241,8 +229,7 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker(
241229 const uint8_t * __restrict__ dataset_ptr,
242230 const typename DescriptorT::CODE_BOOK_T* __restrict__ vq_code_book_ptr,
243231 uint32_t dim,
244- uint32_t pq_codebook_ptr,
245- uint32_t n_subspace) -> typename DescriptorT::DISTANCE_T
232+ uint32_t pq_codebook_ptr) -> typename DescriptorT::DISTANCE_T
246233{
247234 using DISTANCE_T = typename DescriptorT::DISTANCE_T;
248235 using LOAD_T = typename DescriptorT::LOAD_T;
@@ -262,8 +249,9 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker(
262249 constexpr auto kTeamMask = DescriptorT::kTeamSize - 1 ;
263250 constexpr auto kTeamVLen = TeamSize * vlen;
264251
265- const auto laneId = threadIdx .x & kTeamMask ;
266- DISTANCE_T norm = 0 ;
252+ const auto n_subspace = raft::div_rounding_up_unsafe (dim, PQ_LEN);
253+ const auto laneId = threadIdx .x & kTeamMask ;
254+ DISTANCE_T norm = 0 ;
267255 for (uint32_t elem_offset = 0 ; elem_offset * PQ_LEN < dim;
268256 elem_offset += DatasetBlockDim / PQ_LEN) {
269257 // Loading PQ codes
@@ -369,11 +357,10 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
369357 uint32_t vq_code;
370358 device::ldg_cg (vq_code, reinterpret_cast <const std::uint32_t *>(dataset_ptr));
371359 return compute_distance_vpq_worker<DescriptorT>(
372- dataset_ptr,
360+ dataset_ptr + 4 /* advance dataset pointer by the size of vq_code */ ,
373361 DescriptorT::vq_code_book_ptr (args) + args.dim * vq_code,
374362 args.dim ,
375- args.smem_ws_ptr ,
376- DescriptorT::n_subspace (args));
363+ args.smem_ws_ptr );
377364}
378365
379366template <cuvs::distance::DistanceType Metric,
@@ -389,7 +376,6 @@ RAFT_KERNEL __launch_bounds__(1, 1)
389376 vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t <DataT, IndexT, DistanceT>* out,
390377 const std::uint8_t * encoded_dataset_ptr,
391378 uint32_t encoded_dataset_dim,
392- uint32_t n_subspace,
393379 const CodebookT* vq_code_book_ptr,
394380 const CodebookT* pq_code_book_ptr,
395381 IndexT size,
@@ -410,7 +396,6 @@ RAFT_KERNEL __launch_bounds__(1, 1)
410396 reinterpret_cast <typename base_type::compute_distance_type*>(&compute_distance_vpq<desc_type>),
411397 encoded_dataset_ptr,
412398 encoded_dataset_dim,
413- n_subspace,
414399 vq_code_book_ptr,
415400 pq_code_book_ptr,
416401 size,
@@ -438,7 +423,6 @@ vpq_descriptor_spec<Metric,
438423 DistanceT>::init_(const cagra::search_params& params,
439424 const std::uint8_t * encoded_dataset_ptr,
440425 uint32_t encoded_dataset_dim,
441- uint32_t n_subspace,
442426 const CodebookT* vq_code_book_ptr,
443427 const CodebookT* pq_code_book_ptr,
444428 IndexT size,
@@ -460,7 +444,6 @@ vpq_descriptor_spec<Metric,
460444 nullptr ,
461445 encoded_dataset_ptr,
462446 encoded_dataset_dim,
463- n_subspace,
464447 vq_code_book_ptr,
465448 pq_code_book_ptr,
466449 size,
@@ -477,7 +460,6 @@ vpq_descriptor_spec<Metric,
477460 DistanceT><<<1 , 1 , 0 , stream>>> (result.dev_ptr ,
478461 encoded_dataset_ptr,
479462 encoded_dataset_dim,
480- n_subspace,
481463 vq_code_book_ptr,
482464 pq_code_book_ptr,
483465 size,
0 commit comments