11/*
2- * Copyright (c) 2024, NVIDIA CORPORATION.
2+ * Copyright (c) 2024-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -602,7 +602,7 @@ void search(const raft::resources& clique,
602602 int64_t n_batches = raft::ceildiv (n_rows, (int64_t )n_rows_per_batch);
603603 if (n_batches <= 1 ) n_rows_per_batch = n_rows;
604604
605- if (merge_mode == MERGE_ON_ROOT_RANK) {
605+ if (merge_mode == MERGE_ON_ROOT_RANK && index. num_ranks_ > 1 ) {
606606 RAFT_LOG_DEBUG (" SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows" ,
607607 n_batches,
608608 n_rows_per_batch);
@@ -617,7 +617,7 @@ void search(const raft::resources& clique,
617617 n_cols,
618618 n_neighbors,
619619 n_batches);
620- } else if (merge_mode == TREE_MERGE) {
620+ } else if (merge_mode == TREE_MERGE && index. num_ranks_ > 1 ) {
621621 RAFT_LOG_DEBUG (
622622 " SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows" , n_batches, n_rows_per_batch);
623623 sharded_search_with_tree_merge (clique,
@@ -631,6 +631,28 @@ void search(const raft::resources& clique,
631631 n_cols,
632632 n_neighbors,
633633 n_batches);
634+ } else {
635+ const int rank = 0 ;
636+ #pragma omp parallel for
637+ for (int64_t batch_idx = 0 ; batch_idx < n_batches; batch_idx++) {
638+ int64_t offset = batch_idx * n_rows_per_batch;
639+ int64_t query_offset = offset * n_cols;
640+ int64_t output_offset = offset * n_neighbors;
641+ int64_t n_rows_of_current_batch = std::min (n_rows_per_batch, n_rows - offset);
642+
643+ run_search_batch (clique,
644+ index,
645+ rank,
646+ search_params,
647+ queries,
648+ neighbors,
649+ distances,
650+ query_offset,
651+ output_offset,
652+ n_rows_of_current_batch,
653+ n_cols,
654+ n_neighbors);
655+ }
634656 }
635657 }
636658}
0 commit comments