Skip to content

Commit 68480c9

Browse files
authored
Brute force knn tile size heuristic (#316)
This PR modifies the tile size heuristic for brute force knn as mentioned in (#277). It also removes some unneeded cuda calls to save a couple of microseconds which might be relevant when running smaller batches. CC @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) - Ben Frederickson (https://github.com/benfred) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Ben Frederickson (https://github.com/benfred) - Tamas Bela Feher (https://github.com/tfeher) URL: #316
1 parent 2124789 commit 68480c9

File tree

2 files changed

+36
-41
lines changed

2 files changed

+36
-41
lines changed

cpp/src/neighbors/detail/faiss_distance_utils.h

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,39 @@ inline void chooseTileSize(size_t numQueries,
1414
size_t numCentroids,
1515
size_t dim,
1616
size_t elementSize,
17-
size_t totalMem,
1817
size_t& tileRows,
1918
size_t& tileCols)
2019
{
20+
// 512 seems to be a batch size sweetspot for float32.
21+
// If we are on float16, increase to 512.
22+
// If the k size (vec dim) of the matrix multiplication is small (<= 32),
23+
// increase to 1024.
24+
size_t preferredTileRows = 512;
25+
if (dim <= 32) { preferredTileRows = 1024; }
26+
27+
tileRows = std::min(preferredTileRows, numQueries);
28+
2129
// The matrix multiplication should be large enough to be efficient, but if
2230
// it is too large, we seem to lose efficiency as opposed to
2331
// double-streaming. Each tile size here defines 1/2 of the memory use due
2432
// to double streaming. We ignore available temporary memory, as that is
2533
// adjusted independently by the user and can thus meet these requirements
2634
// (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs,
2735
// prefer 768 MB of usage. Otherwise, prefer 1 GB of usage.
28-
size_t targetUsage = 0;
29-
30-
if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) {
31-
targetUsage = 512 * 1024 * 1024;
32-
} else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) {
33-
targetUsage = 768 * 1024 * 1024;
36+
size_t targetUsage = 512 * 1024 * 1024;
37+
if (tileRows * numCentroids * elementSize * 2 <= targetUsage) {
38+
tileCols = numCentroids;
3439
} else {
35-
targetUsage = 1024 * 1024 * 1024;
36-
}
40+
// only query total memory in case it potentially impacts tilesize
41+
size_t totalMem = rmm::available_device_memory().second;
3742

38-
targetUsage /= 2 * elementSize;
43+
if (totalMem > ((size_t)8) * 1024 * 1024 * 1024) {
44+
targetUsage = 1024 * 1024 * 1024;
45+
} else if (totalMem > ((size_t)4) * 1024 * 1024 * 1024) {
46+
targetUsage = 768 * 1024 * 1024;
47+
}
3948

40-
// 512 seems to be a batch size sweetspot for float32.
41-
// If we are on float16, increase to 512.
42-
// If the k size (vec dim) of the matrix multiplication is small (<= 32),
43-
// increase to 1024.
44-
size_t preferredTileRows = 512;
45-
if (dim <= 32) { preferredTileRows = 1024; }
46-
47-
tileRows = std::min(preferredTileRows, numQueries);
48-
49-
// tileCols is the remainder size
50-
tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
49+
tileCols = std::min(targetUsage / (2 * elementSize * tileRows), numCentroids);
50+
}
5151
}
5252
} // namespace cuvs::neighbors::detail::faiss_select

cpp/src/neighbors/detail/knn_brute_force.cuh

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,12 @@ void tiled_brute_force_knn(const raft::resources& handle,
8181
const uint32_t* filter_bitmap = nullptr)
8282
{
8383
// Figure out the number of rows/cols to tile for
84-
size_t tile_rows = 0;
85-
size_t tile_cols = 0;
86-
auto stream = raft::resource::get_cuda_stream(handle);
87-
auto device_memory = raft::resource::get_workspace_resource(handle);
88-
auto total_mem = rmm::available_device_memory().second;
84+
size_t tile_rows = 0;
85+
size_t tile_cols = 0;
86+
auto stream = raft::resource::get_cuda_stream(handle);
8987

9088
cuvs::neighbors::detail::faiss_select::chooseTileSize(
91-
m, n, d, sizeof(DistanceT), total_mem, tile_rows, tile_cols);
89+
m, n, d, sizeof(DistanceT), tile_rows, tile_cols);
9290

9391
// for unittesting, its convenient to be able to put a max size on the tiles
9492
// so we can test the tiling logic without having to use huge inputs.
@@ -356,27 +354,26 @@ void brute_force_knn_impl(
356354

357355
ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size");
358356

359-
std::vector<IdxType>* id_ranges;
360-
if (translations == nullptr) {
357+
std::vector<IdxType> id_ranges;
358+
if (translations != nullptr) {
359+
// use the given translations
360+
id_ranges.insert(id_ranges.end(), translations->begin(), translations->end());
361+
} else if (input.size() > 1) {
361362
// If we don't have explicit translations
362363
// for offsets of the indices, build them
363364
// from the local partitions
364-
id_ranges = new std::vector<IdxType>();
365365
IdxType total_n = 0;
366366
for (size_t i = 0; i < input.size(); i++) {
367-
id_ranges->push_back(total_n);
367+
id_ranges.push_back(total_n);
368368
total_n += sizes[i];
369369
}
370-
} else {
371-
// otherwise, use the given translations
372-
id_ranges = translations;
373370
}
374371

375-
int device;
376-
RAFT_CUDA_TRY(cudaGetDevice(&device));
377-
378-
rmm::device_uvector<IdxType> trans(id_ranges->size(), userStream);
379-
raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream);
372+
rmm::device_uvector<IdxType> trans(0, userStream);
373+
if (id_ranges.size() > 0) {
374+
trans.resize(id_ranges.size(), userStream);
375+
raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), userStream);
376+
}
380377

381378
rmm::device_uvector<DistType> all_D(0, userStream);
382379
rmm::device_uvector<IdxType> all_I(0, userStream);
@@ -513,8 +510,6 @@ void brute_force_knn_impl(
513510
// no translations or partitions to combine, it can be skipped.
514511
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
515512
}
516-
517-
if (translations == nullptr) delete id_ranges;
518513
};
519514

520515
template <typename T,

0 commit comments

Comments
 (0)