Skip to content

Commit 8d4d1a2

Browse files
committed
add more syncs, use thrust_policy
1 parent 4d36c80 commit 8d4d1a2

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

cpp/src/neighbors/detail/nn_descent.cuh

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <raft/core/error.hpp>
2525
#include <raft/core/host_mdarray.hpp>
2626
#include <raft/core/resource/cuda_stream.hpp>
27+
#include <raft/core/resource/thrust_policy.hpp>
2728
#include <raft/core/resources.hpp>
2829

2930
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
@@ -1162,15 +1163,19 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
11621163
{
11631164
static_assert(NUM_SAMPLES <= 32);
11641165

1165-
thrust::fill(thrust::device,
1166+
thrust::fill(raft::resource::get_thrust_policy(res),
11661167
dists_buffer_.data_handle(),
11671168
dists_buffer_.data_handle() + dists_buffer_.size(),
11681169
std::numeric_limits<float>::max());
1169-
thrust::fill(thrust::device,
1170+
thrust::fill(raft::resource::get_thrust_policy(res),
11701171
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()),
11711172
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()) + graph_buffer_.size(),
11721173
std::numeric_limits<Index_t>::max());
1173-
thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0);
1174+
thrust::fill(raft::resource::get_thrust_policy(res),
1175+
d_locks_.data_handle(),
1176+
d_locks_.data_handle() + d_locks_.size(),
1177+
0);
1178+
raft::resource::sync_stream(res);
11741179
};
11751180

11761181
template <typename Data_t, typename Index_t>
@@ -1190,7 +1195,7 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
11901195
template <typename Data_t, typename Index_t>
11911196
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
11921197
{
1193-
thrust::fill(thrust::device.on(stream),
1198+
thrust::fill(raft::resource::get_thrust_policy(res),
11941199
dists_buffer_.data_handle(),
11951200
dists_buffer_.data_handle() + dists_buffer_.size(),
11961201
std::numeric_limits<float>::max());
@@ -1209,6 +1214,7 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
12091214
DEGREE_ON_DEVICE,
12101215
d_locks_.data_handle(),
12111216
l2_norms_.data_handle());
1217+
raft::resource::sync_stream(res);
12121218
}
12131219

12141220
template <typename Data_t, typename Index_t>
@@ -1240,10 +1246,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
12401246
batch.offset());
12411247
}
12421248
1243-
thrust::fill(thrust::device.on(stream),
1249+
thrust::fill(raft::resource::get_thrust_policy(res),
12441250
(Index_t*)graph_buffer_.data_handle(),
12451251
(Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(),
12461252
std::numeric_limits<Index_t>::max());
1253+
raft::resource::sync_stream(res);
12471254
12481255
graph_.clear();
12491256
graph_.init_random_graph();
@@ -1330,6 +1337,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
13301337
graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE);
13311338
}
13321339
1340+
raft::resource::sync_stream(res);
13331341
graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()),
13341342
thrust::raw_pointer_cast(dists_host_buffer_.data()),
13351343
DEGREE_ON_DEVICE,
@@ -1415,6 +1423,7 @@ void build(raft::resources const& res,
14151423
14161424
GNND<const T, int> nnd(res, build_config);
14171425
nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle());
1426+
raft::resource::sync_stream(res);
14181427
14191428
#pragma omp parallel for
14201429
for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) {

0 commit comments

Comments
 (0)