Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions cpp/src/neighbors/mg/snmg.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -602,7 +602,7 @@ void search(const raft::resources& clique,
int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch);
if (n_batches <= 1) n_rows_per_batch = n_rows;

if (merge_mode == MERGE_ON_ROOT_RANK) {
if (merge_mode == MERGE_ON_ROOT_RANK && index.num_ranks_ > 1) {
RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows",
n_batches,
n_rows_per_batch);
Expand All @@ -617,7 +617,7 @@ void search(const raft::resources& clique,
n_cols,
n_neighbors,
n_batches);
} else if (merge_mode == TREE_MERGE) {
} else if (merge_mode == TREE_MERGE && index.num_ranks_ > 1) {
RAFT_LOG_DEBUG(
"SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch);
sharded_search_with_tree_merge(clique,
Expand All @@ -631,6 +631,28 @@ void search(const raft::resources& clique,
n_cols,
n_neighbors,
n_batches);
} else {
const int rank = 0;
#pragma omp parallel for
for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
int64_t offset = batch_idx * n_rows_per_batch;
int64_t query_offset = offset * n_cols;
int64_t output_offset = offset * n_neighbors;
int64_t n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset);

run_search_batch(clique,
index,
rank,
search_params,
queries,
neighbors,
distances,
query_offset,
output_offset,
n_rows_of_current_batch,
n_cols,
n_neighbors);
}
}
}
}
Expand Down