1717#pragma once
1818#include " ./knn_brute_force.cuh"
1919
20+ #include < raft/core/resource/cuda_stream.hpp>
21+ #include < raft/core/resources.hpp>
22+ #include < raft/linalg/map.cuh>
2023#include < raft/linalg/unary_op.cuh>
2124#include < raft/sparse/convert/csr.cuh>
2225#include < raft/sparse/linalg/symmetrize.cuh>
2326#include < raft/util/cuda_utils.cuh>
2427#include < raft/util/cudart_utils.hpp>
2528
2629#include < rmm/device_uvector.hpp>
27- #include < rmm/exec_policy.hpp>
28-
29- #include < thrust/iterator/counting_iterator.h>
30- #include < thrust/iterator/zip_iterator.h>
31- #include < thrust/transform.h>
32- #include < thrust/tuple.h>
3330
3431namespace cuvs ::neighbors::detail::reachability {
3532
@@ -47,17 +44,19 @@ namespace cuvs::neighbors::detail::reachability {
4744 * @param[in] stream stream for which to order cuda operations
4845 */
4946template <typename value_idx, typename value_t , int tpb = 256 >
50- void core_distances (
51- value_t * knn_dists, int min_samples, int n_neighbors, size_t n, value_t * out, cudaStream_t stream)
47+ void core_distances (raft::resources const & handle,
48+ value_t * knn_dists,
49+ int min_samples,
50+ int n_neighbors,
51+ size_t n,
52+ value_t * out)
5253{
5354 ASSERT (n_neighbors >= min_samples,
5455 " the size of the neighborhood should be greater than or equal to min_samples" );
5556
56- auto exec_policy = rmm::exec_policy (stream);
57-
58- auto indices = thrust::make_counting_iterator<value_idx>(0 );
57+ auto out_view = raft::make_device_vector_view<value_t , value_idx>(out, n);
5958
60- thrust::transform (exec_policy, indices, indices + n, out , [=] __device__ (value_idx row) {
59+ raft::linalg::map_offset (handle, out_view , [=] __device__ (value_idx row) {
6160 return knn_dists[row * n_neighbors + (min_samples - 1 )];
6261 });
6362}
@@ -118,7 +117,7 @@ void _compute_core_dists(const raft::resources& handle,
118117 compute_knn (handle, X, inds.data (), dists.data (), m, n, X, m, min_samples, metric);
119118
120119 // Slice core distances (distances to kth nearest neighbor)
121- core_distances<value_idx>(dists.data (), min_samples, min_samples, m, core_dists, stream );
120+ core_distances<value_idx>(handle, dists.data (), min_samples, min_samples, m, core_dists);
122121}
123122
124123// Functor to post-process distances into reachability space
@@ -202,8 +201,7 @@ void mutual_reachability_graph(const raft::resources& handle,
202201 RAFT_EXPECTS (metric == cuvs::distance::DistanceType::L2SqrtExpanded,
203202 " Currently only L2 expanded distance is supported" );
204203
205- auto stream = raft::resource::get_cuda_stream (handle);
206- auto exec_policy = raft::resource::get_thrust_policy (handle);
204+ auto stream = raft::resource::get_cuda_stream (handle);
207205
208206 rmm::device_uvector<value_idx> coo_rows (min_samples * m, stream);
209207 rmm::device_uvector<value_idx> inds (min_samples * m, stream);
@@ -213,7 +211,7 @@ void mutual_reachability_graph(const raft::resources& handle,
213211 compute_knn (handle, X, inds.data (), dists.data (), m, n, X, m, min_samples, metric);
214212
215213 // Slice core distances (distances to kth nearest neighbor)
216- core_distances<value_idx>(dists.data (), min_samples, min_samples, m, core_dists, stream );
214+ core_distances<value_idx>(handle, dists.data (), min_samples, min_samples, m, core_dists);
217215
218216 /* *
219217 * Compute L2 norm
@@ -222,12 +220,12 @@ void mutual_reachability_graph(const raft::resources& handle,
222220 handle, inds.data (), dists.data (), X, m, n, min_samples, core_dists, (value_t )1.0 / alpha);
223221
224222 // self-loops get max distance
225- auto coo_rows_counting_itr = thrust::make_counting_iterator<value_idx>( 0 );
226- thrust::transform (exec_policy,
227- coo_rows_counting_itr,
228- coo_rows_counting_itr + (m * min_samples),
229- coo_rows. data (),
230- [min_samples] __device__ (value_idx c) -> value_idx { return c / min_samples; });
223+ auto coo_rows_view =
224+ raft::make_device_vector_view<value_idx, value_idx>(coo_rows. data (), m * min_samples);
225+ raft::linalg::map_offset (
226+ handle, coo_rows_view, [min_samples] __device__ (value_idx c) -> value_idx {
227+ return c / min_samples;
228+ });
231229
232230 raft::sparse::linalg::symmetrize (handle,
233231 coo_rows.data (),
@@ -241,18 +239,20 @@ void mutual_reachability_graph(const raft::resources& handle,
241239 raft::sparse::convert::sorted_coo_to_csr (out.rows (), out.nnz , indptr, m + 1 , stream);
242240
243241 // self-loops get max distance
244- auto transform_in =
245- thrust::make_zip_iterator (thrust::make_tuple (out.rows (), out.cols (), out.vals ()));
242+ auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t >(out.rows (), out.nnz );
243+ auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t >(out.cols (), out.nnz );
244+ auto vals_in_view = raft::make_device_vector_view<const value_t , nnz_t >(out.vals (), out.nnz );
245+ auto vals_out_view = raft::make_device_vector_view<value_t , nnz_t >(out.vals (), out.nnz );
246246
247- thrust::transform (exec_policy,
248- transform_in ,
249- transform_in + out. nnz ,
250- out. vals (),
251- [=] __device__ ( const thrust::tuple<value_idx, value_idx, value_t >& tup) {
252- return thrust::get< 0 >(tup) == thrust::get< 1 >(tup)
253- ? std::numeric_limits< value_t >:: max ()
254- : thrust::get< 2 >(tup);
255- } );
247+ raft::linalg::map (
248+ handle ,
249+ vals_out_view ,
250+ [=] __device__ ( const value_idx row, const value_idx col, const value_t val) {
251+ return row == col ? std::numeric_limits< value_t >:: max () : val;
252+ },
253+ rows_view,
254+ cols_view,
255+ vals_in_view );
256256}
257257
258258} // namespace cuvs::neighbors::detail::reachability
0 commit comments