2424#include < raft/core/error.hpp>
2525#include < raft/core/host_mdarray.hpp>
2626#include < raft/core/resource/cuda_stream.hpp>
27+ #include < raft/core/resource/thrust_policy.hpp>
2728#include < raft/core/resources.hpp>
2829
2930#include < raft/util/arch.cuh> // raft::util::arch::SM_*
@@ -1162,15 +1163,19 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
11621163{
11631164 static_assert (NUM_SAMPLES <= 32 );
11641165
1165- thrust::fill (thrust::device ,
1166+ thrust::fill (raft::resource::get_thrust_policy (res) ,
11661167 dists_buffer_.data_handle (),
11671168 dists_buffer_.data_handle () + dists_buffer_.size (),
11681169 std::numeric_limits<float >::max ());
1169- thrust::fill (thrust::device ,
1170+ thrust::fill (raft::resource::get_thrust_policy (res) ,
11701171 reinterpret_cast <Index_t*>(graph_buffer_.data_handle ()),
11711172 reinterpret_cast <Index_t*>(graph_buffer_.data_handle ()) + graph_buffer_.size (),
11721173 std::numeric_limits<Index_t>::max ());
1173- thrust::fill (thrust::device, d_locks_.data_handle (), d_locks_.data_handle () + d_locks_.size (), 0 );
1174+ thrust::fill (raft::resource::get_thrust_policy (res),
1175+ d_locks_.data_handle (),
1176+ d_locks_.data_handle () + d_locks_.size (),
1177+ 0 );
1178+ raft::resource::sync_stream (res);
11741179};
11751180
11761181template <typename Data_t, typename Index_t>
@@ -1190,7 +1195,7 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
11901195template <typename Data_t, typename Index_t>
11911196void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
11921197{
1193- thrust::fill (thrust::device. on (stream ),
1198+ thrust::fill (raft::resource::get_thrust_policy (res ),
11941199 dists_buffer_.data_handle (),
11951200 dists_buffer_.data_handle () + dists_buffer_.size (),
11961201 std::numeric_limits<float >::max ());
@@ -1209,6 +1214,7 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
12091214 DEGREE_ON_DEVICE,
12101215 d_locks_.data_handle (),
12111216 l2_norms_.data_handle ());
1217+ raft::resource::sync_stream (res);
12121218}
12131219
12141220template <typename Data_t, typename Index_t>
@@ -1240,10 +1246,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
12401246 batch.offset());
12411247 }
12421248
1243- thrust::fill (thrust::device.on(stream ),
1249+ thrust::fill (raft::resource::get_thrust_policy(res ),
12441250 (Index_t*)graph_buffer_.data_handle(),
12451251 (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(),
12461252 std::numeric_limits<Index_t>::max());
1253+ raft::resource::sync_stream (res);
12471254
12481255 graph_.clear();
12491256 graph_.init_random_graph();
@@ -1330,6 +1337,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
13301337 graph_.sample_graph_new (thrust::raw_pointer_cast (graph_host_buffer_.data ()), DEGREE_ON_DEVICE);
13311338 }
13321339
1340+ raft::resource::sync_stream (res);
13331341 graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()),
13341342 thrust::raw_pointer_cast (dists_host_buffer_.data()),
13351343 DEGREE_ON_DEVICE,
@@ -1415,6 +1423,7 @@ void build(raft::resources const& res,
14151423
14161424 GNND<const T, int > nnd (res, build_config);
14171425 nnd.build (dataset.data_handle (), dataset.extent (0 ), int_graph.data_handle ());
1426+ raft::resource::sync_stream (res);
14181427
14191428#pragma omp parallel for
14201429 for (size_t i = 0 ; i < static_cast <size_t >(dataset.extent (0 )); i++) {
0 commit comments