Status: Open
Labels: bug
Description
Describe the bug
The tiling heuristic incorrectly selects tiled execution of brute-force k-NN even when the whole output would still fit in memory as a single tile. As a result, the k-NN search is slower than a plain `torch.matmul` + `torch.topk`.
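A back-of-the-envelope check of the repro's sizes shows the scale involved: the full 8 × 1,000,000 float32 distance matrix is only about 30.5 MiB, and that is exactly what the torch baseline materializes in a single `torch.matmul` call. The 512 MiB budget below is a hypothetical figure used only for comparison, not a constant taken from cuVS.

```python
# Memory needed to hold the whole query-by-dataset distance matrix for the repro.
N_QUERIES = 8            # n_queries in the repro
N_ROWS = 1_000_000       # dataset rows in the repro
BYTES_PER_ELEM = 4       # float32 distances

full_output_bytes = N_QUERIES * N_ROWS * BYTES_PER_ELEM

# Hypothetical tile budget for comparison (an assumption, not a cuVS constant).
assumed_budget_bytes = 512 * 2**20

print(f"full distance matrix: {full_output_bytes / 2**20:.1f} MiB")  # 30.5 MiB
print("fits in one tile:", full_output_bytes <= assumed_budget_bytes)  # True
```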
Additional context
The column tile size is computed on this line:

```cpp
tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
```
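To see how this line can force tiling, here is a minimal Python model of it under stated assumptions. `column_tiles` is a hypothetical helper that mirrors the C++ expression; the element budget (`2**27`, i.e. 1 GiB split over two float32 buffers) and `preferredTileRows = 512` are assumed values, and `numCentroids` is presumably the dataset row count in the brute-force path. The sketch compares dividing the budget by the preferred row count against dividing by the rows actually present when the row tile is clamped to the number of queries (also an assumption).

```python
import math


def column_tiles(target_usage: int, divisor_rows: int, num_centroids: int) -> tuple[int, int]:
    """Model of `tileCols = std::min(targetUsage / divisor, numCentroids)`.

    Returns (tile_cols, number of column tiles needed to cover num_centroids).
    """
    tile_cols = min(target_usage // divisor_rows, num_centroids)
    return tile_cols, math.ceil(num_centroids / tile_cols)


# Assumed values (hypothetical, chosen for illustration only):
TARGET_USAGE = 2**27        # element budget, e.g. 1 GiB / (2 * sizeof(float))
PREFERRED_TILE_ROWS = 512   # assumed preferred row-tile size
NUM_QUERIES = 8             # n_queries in the repro
NUM_CENTROIDS = 1_000_000   # dataset rows in the repro

# Dividing by the preferred row count ignores that only 8 query rows exist.
print(column_tiles(TARGET_USAGE, PREFERRED_TILE_ROWS, NUM_CENTROIDS))  # (262144, 4)

# Dividing by the row tile actually used (hypothetical alternative).
actual_rows = min(PREFERRED_TILE_ROWS, NUM_QUERIES)
print(column_tiles(TARGET_USAGE, actual_rows, NUM_CENTROIDS))  # (1000000, 1)
```

Under these assumptions the first variant splits the 1M-row dataset into 4 column tiles although a single tile would fit the budget.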
Steps/Code to reproduce bug
Run a brute-force vector search on a 1M x 128 dataset with a small number of queries. The example below uses pylibraft, the RAFT Python wrappers, whose brute-force k-NN currently shares the same code as cuVS.
```python
import rmm

# Use a 1 GiB RMM pool for RMM-backed allocations (e.g. pylibraft temporaries).
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource(), initial_pool_size=2**30)
rmm.mr.set_current_device_resource(mr)

import torch
import numpy as np
import pylibraft
from pylibraft.common import Handle
from pylibraft.neighbors.brute_force import knn
import cupy as cp
import time

device = torch.device("cuda")


class BenchmarkTimer:
    """Provides a context manager that runs a code block `reps` times
    and records results to the instance variable `timings`. Use like:

    .. code-block:: python

        timer = BenchmarkTimer(reps=5)
        for _ in timer.benchmark_runs():
            ... do something ...
        print(np.min(timer.timings))

    This class is part of the rapids/cuml benchmark suite
    """

    def __init__(self, reps=1, warmup=0):
        self.warmup = warmup
        self.reps = reps
        self.timings = []

    def benchmark_runs(self):
        for r in range(self.reps + self.warmup):
            t0 = time.time()
            yield r
            t1 = time.time()
            # Discard the warm-up iterations.
            if r >= self.warmup:
                self.timings.append(t1 - t0)


rows = 1000000
cols = 128
n_queries = 8
k = 10

dataset = torch.randn(rows, cols, device=device)
queries = torch.randn(n_queries, cols, device=device)
# Share the same device buffers with CuPy via __cuda_array_interface__.
dataset_cp = cp.asarray(dataset)
queries_cp = cp.asarray(queries)

# Baseline: inner-product scores from a single GEMM, then top-k.
timer = BenchmarkTimer(reps=100, warmup=5)
for rep in timer.benchmark_runs():
    distance = torch.matmul(queries, dataset.T)
    distances, indices = torch.topk(distance, k, dim=1, largest=True)
    # CUDA kernels launch asynchronously; wait so the timing covers the work.
    torch.cuda.synchronize()
timings = np.asarray(timer.timings)
avg_time = timings.mean() * 1000
std_time = timings.std() * 1000
print("Average search time (torch matmul + topk): {0:7.3f} +/- {1:7.3f} ms".format(avg_time, std_time))

# Brute-force k-NN through pylibraft (same code path as cuVS).
timer = BenchmarkTimer(reps=100, warmup=5)
handle = Handle()
for rep in timer.benchmark_runs():
    distances, indices = knn(dataset_cp, queries_cp, k=k, metric="sqeuclidean", handle=handle)
    handle.sync()
timings = np.asarray(timer.timings)
avg_time = timings.mean() * 1000
std_time = timings.std() * 1000
print("Average search time (pylibraft knn): {0:7.3f} +/- {1:7.3f} ms".format(avg_time, std_time))
```
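The two timed paths use different metrics (inner product in the torch baseline, squared Euclidean in `knn`), but both reduce to the standard GEMM-plus-top-k formulation: expanding ‖q − x‖² = ‖q‖² − 2 q·x + ‖x‖² adds only row and column norms to the same matrix product. Below is a small CPU sketch of that formulation in NumPy, for illustration only; it is not cuVS's implementation, and `brute_force_knn_sqeuclidean` is a hypothetical helper name.

```python
import numpy as np


def brute_force_knn_sqeuclidean(dataset: np.ndarray, queries: np.ndarray, k: int):
    """Exact k-NN under squared Euclidean distance via one GEMM plus top-k."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2 for every (query, row) pair.
    q_norms = np.sum(queries * queries, axis=1, keepdims=True)  # (n_queries, 1)
    x_norms = np.sum(dataset * dataset, axis=1)                  # (n_rows,)
    dist = q_norms - 2.0 * (queries @ dataset.T) + x_norms       # (n_queries, n_rows)
    # argpartition puts the k smallest entries of each row first (unordered).
    part = np.argpartition(dist, k - 1, axis=1)[:, :k]
    part_dist = np.take_along_axis(dist, part, axis=1)
    # Sort only those k candidates per query.
    order = np.argsort(part_dist, axis=1)
    indices = np.take_along_axis(part, order, axis=1)
    distances = np.take_along_axis(part_dist, order, axis=1)
    return distances, indices


rng = np.random.default_rng(0)
dataset = rng.standard_normal((2_000, 128)).astype(np.float32)
queries = rng.standard_normal((8, 128)).astype(np.float32)
distances, indices = brute_force_knn_sqeuclidean(dataset, queries, k=10)
print(distances.shape, indices.shape)  # (8, 10) (8, 10)
```

The full `dist` matrix is materialized at once here, which is the untiled strategy the torch baseline also uses; tiling only pays off when that matrix would not fit in memory.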