Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions cpp/src/neighbors/detail/faiss_distance_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,39 @@ inline void chooseTileSize(size_t numQueries,
size_t numCentroids,
size_t dim,
size_t elementSize,
size_t totalMem,
size_t& tileRows,
size_t& tileCols)
{
// 512 seems to be a batch size sweetspot for float32.
// If we are on float16, increase to 512.
// If the k size (vec dim) of the matrix multiplication is small (<= 32),
// increase to 1024.
size_t preferredTileRows = 512;
if (dim <= 32) { preferredTileRows = 1024; }

tileRows = std::min(preferredTileRows, numQueries);

// The matrix multiplication should be large enough to be efficient, but if
// it is too large, we seem to lose efficiency as opposed to
// double-streaming. Each tile size here defines 1/2 of the memory use due
// to double streaming. We ignore available temporary memory, as that is
// adjusted independently by the user and can thus meet these requirements
// (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs,
// prefer 768 MB of usage. Otherwise, prefer 1 GB of usage.
size_t targetUsage = 0;

if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) {
targetUsage = 512 * 1024 * 1024;
} else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) {
targetUsage = 768 * 1024 * 1024;
size_t targetUsage = 512 * 1024 * 1024;
if (tileRows * numCentroids * elementSize * 2 <= targetUsage) {
tileCols = numCentroids;
} else {
targetUsage = 1024 * 1024 * 1024;
}
// only query total memory in case it potentially impacts tilesize
size_t totalMem = rmm::available_device_memory().second;

targetUsage /= 2 * elementSize;
if (totalMem > ((size_t)8) * 1024 * 1024 * 1024) {
targetUsage = 1024 * 1024 * 1024;
} else if (totalMem > ((size_t)4) * 1024 * 1024 * 1024) {
targetUsage = 768 * 1024 * 1024;
}

// 512 seems to be a batch size sweetspot for float32.
// If we are on float16, increase to 512.
// If the k size (vec dim) of the matrix multiplication is small (<= 32),
// increase to 1024.
size_t preferredTileRows = 512;
if (dim <= 32) { preferredTileRows = 1024; }

tileRows = std::min(preferredTileRows, numQueries);

// tileCols is the remainder size
tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
tileCols = std::min(targetUsage / (2 * elementSize * tileRows), numCentroids);
}
}
} // namespace cuvs::neighbors::detail::faiss_select
35 changes: 15 additions & 20 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,12 @@ void tiled_brute_force_knn(const raft::resources& handle,
const uint32_t* filter_bitmap = nullptr)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
size_t tile_cols = 0;
auto stream = raft::resource::get_cuda_stream(handle);
auto device_memory = raft::resource::get_workspace_resource(handle);
auto total_mem = rmm::available_device_memory().second;
size_t tile_rows = 0;
size_t tile_cols = 0;
auto stream = raft::resource::get_cuda_stream(handle);

cuvs::neighbors::detail::faiss_select::chooseTileSize(
m, n, d, sizeof(DistanceT), total_mem, tile_rows, tile_cols);
m, n, d, sizeof(DistanceT), tile_rows, tile_cols);

// for unittesting, its convenient to be able to put a max size on the tiles
// so we can test the tiling logic without having to use huge inputs.
Expand Down Expand Up @@ -356,27 +354,26 @@ void brute_force_knn_impl(

ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size");

std::vector<IdxType>* id_ranges;
if (translations == nullptr) {
std::vector<IdxType> id_ranges;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

longer term - we can probably remove the code that handles translations entirely. Its not being used in the public api anymore, and is just left over from the RAFT version. (doesn't need to change in this PR though)

if (translations != nullptr) {
// use the given translations
id_ranges.insert(id_ranges.end(), translations->begin(), translations->end());
} else if (input.size() > 1) {
// If we don't have explicit translations
// for offsets of the indices, build them
// from the local partitions
id_ranges = new std::vector<IdxType>();
IdxType total_n = 0;
for (size_t i = 0; i < input.size(); i++) {
id_ranges->push_back(total_n);
id_ranges.push_back(total_n);
total_n += sizes[i];
}
} else {
// otherwise, use the given translations
id_ranges = translations;
}

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

rmm::device_uvector<IdxType> trans(id_ranges->size(), userStream);
raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream);
rmm::device_uvector<IdxType> trans(0, userStream);
if (id_ranges.size() > 0) {
trans.resize(id_ranges.size(), userStream);
raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), userStream);
}

rmm::device_uvector<DistType> all_D(0, userStream);
rmm::device_uvector<IdxType> all_I(0, userStream);
Expand Down Expand Up @@ -513,8 +510,6 @@ void brute_force_knn_impl(
// no translations or partitions to combine, it can be skipped.
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
}

if (translations == nullptr) delete id_ranges;
};

template <typename T,
Expand Down