3131#include < cuvs/neighbors/hnsw.hpp>
3232
3333namespace {
34- template <typename T, typename QueriesT >
34+ template <typename T>
3535void _search (cuvsResources_t res,
3636 cuvsHnswSearchParams params,
3737 cuvsHnswIndex index,
@@ -46,7 +46,7 @@ void _search(cuvsResources_t res,
4646 search_params.ef = params.ef ;
4747 search_params.num_threads = params.numThreads ;
4848
49- using queries_mdspan_type = raft::host_matrix_view<QueriesT const , int64_t , raft::row_major>;
49+ using queries_mdspan_type = raft::host_matrix_view<T const , int64_t , raft::row_major>;
5050 using neighbors_mdspan_type = raft::host_matrix_view<uint64_t , int64_t , raft::row_major>;
5151 using distances_mdspan_type = raft::host_matrix_view<float , int64_t , raft::row_major>;
5252 auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
@@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res,
127127
128128 auto index = *index_c_ptr;
129129 RAFT_EXPECTS (queries.dtype .code == index.dtype .code , " type mismatch between index and queries" );
130- RAFT_EXPECTS (queries.dtype .bits == 32 , " number of bits in queries dtype should be 32" );
131130
132131 if (index.dtype .code == kDLFloat ) {
133- _search<float , float >(
134- res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
132+ _search<float >(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
135133 } else if (index.dtype .code == kDLUInt ) {
136- _search<uint8_t , int >(
137- res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
134+ _search<uint8_t >(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
138135 } else if (index.dtype .code == kDLInt ) {
139- _search<int8_t , int >(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
136+ _search<int8_t >(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
140137 } else {
141138 RAFT_FAIL (" Unsupported index dtype: %d and bits: %d" , queries.dtype .code , queries.dtype .bits );
142139 }
@@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res,
152149 return cuvs::core::translate_exceptions ([=] {
153150 if (index->dtype .code == kDLFloat && index->dtype .bits == 32 ) {
154151 index->addr = reinterpret_cast <uintptr_t >(_deserialize<float >(res, filename, dim, metric));
155- index->dtype .code = kDLFloat ;
156152 } else if (index->dtype .code == kDLUInt && index->dtype .bits == 8 ) {
157153 index->addr = reinterpret_cast <uintptr_t >(_deserialize<uint8_t >(res, filename, dim, metric));
158- index->dtype .code = kDLInt ;
159154 } else if (index->dtype .code == kDLInt && index->dtype .bits == 8 ) {
160155 index->addr = reinterpret_cast <uintptr_t >(_deserialize<int8_t >(res, filename, dim, metric));
161- index->dtype .code = kDLUInt ;
162156 } else {
163157 RAFT_FAIL (" Unsupported dtype in file %s" , filename);
164158 }
0 commit comments