Skip to content

Commit d29a06a

Browse files
authored
Fix single GPU sharded search merge (#1094)
#904 Authors: - Victor Lafargue (https://github.com/viclafargue) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #1094
1 parent 99a6b75 commit d29a06a

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

cpp/src/neighbors/mg/snmg.cuh

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -602,7 +602,7 @@ void search(const raft::resources& clique,
602602
int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch);
603603
if (n_batches <= 1) n_rows_per_batch = n_rows;
604604

605-
if (merge_mode == MERGE_ON_ROOT_RANK) {
605+
if (merge_mode == MERGE_ON_ROOT_RANK && index.num_ranks_ > 1) {
606606
RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows",
607607
n_batches,
608608
n_rows_per_batch);
@@ -617,7 +617,7 @@ void search(const raft::resources& clique,
617617
n_cols,
618618
n_neighbors,
619619
n_batches);
620-
} else if (merge_mode == TREE_MERGE) {
620+
} else if (merge_mode == TREE_MERGE && index.num_ranks_ > 1) {
621621
RAFT_LOG_DEBUG(
622622
"SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch);
623623
sharded_search_with_tree_merge(clique,
@@ -631,6 +631,28 @@ void search(const raft::resources& clique,
631631
n_cols,
632632
n_neighbors,
633633
n_batches);
634+
} else {
635+
const int rank = 0;
636+
#pragma omp parallel for
637+
for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
638+
int64_t offset = batch_idx * n_rows_per_batch;
639+
int64_t query_offset = offset * n_cols;
640+
int64_t output_offset = offset * n_neighbors;
641+
int64_t n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset);
642+
643+
run_search_batch(clique,
644+
index,
645+
rank,
646+
search_params,
647+
queries,
648+
neighbors,
649+
distances,
650+
query_offset,
651+
output_offset,
652+
n_rows_of_current_batch,
653+
n_cols,
654+
n_neighbors);
655+
}
634656
}
635657
}
636658
}

0 commit comments

Comments
 (0)