Skip to content

Commit 93db229

Browse files
ddrcoder authored and facebook-github-bot committed
RabitQ test coverage for SIMD codepaths (facebookresearch#4571)
Summary: Pull Request resolved: facebookresearch#4571
Reviewed By: luciang
Differential Revision: D81825792
1 parent 3671c61 commit 93db229

File tree

1 file changed

+147
-51
lines changed

1 file changed

+147
-51
lines changed

tests/test_rabitq.py

Lines changed: 147 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,11 @@ def random_rotation(d, seed=123):
1717
return Q
1818

1919

20+
# Exercise SIMD codepaths, maintain multiple of 16.
21+
TEST_DIM = 512 * 2 + 256 + 128 + 64 + 16 # 1488
22+
TEST_N = 4096
23+
24+
2025
# based on https://gist.github.com/mdouze/0b2386c31d7fb8b20ae04f3fcbbf4d9d
2126
class ReferenceRabitQ:
2227
"""Exact translation of the paper
@@ -175,53 +180,63 @@ def search(self, x, k):
175180

176181
class TestRaBitQ(unittest.TestCase):
177182
def do_comparison_vs_pq_test(self, metric_type=faiss.METRIC_L2):
178-
ds = datasets.SyntheticDataset(128, 4096, 4096, 100)
183+
ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, 100)
179184
k = 10
180185

181-
# PQ 8-to-1
182-
index_pq = faiss.IndexPQ(ds.d, 16, 8, metric_type)
183-
index_pq.train(ds.get_train())
184-
index_pq.add(ds.get_database())
185-
_, I_pq = index_pq.search(ds.get_queries(), k)
186-
187-
index_rbq = faiss.IndexRaBitQ(ds.d, metric_type)
188-
index_rbq.train(ds.get_train())
189-
index_rbq.add(ds.get_database())
190-
_, I_rbq = index_rbq.search(ds.get_queries(), k)
191-
192-
# try quantized query
193-
rbq_params = faiss.RaBitQSearchParameters(qb=8)
194-
_, I_rbq_q8 = index_rbq.search(ds.get_queries(), k, params=rbq_params)
195-
196-
rbq_params = faiss.RaBitQSearchParameters(qb=4)
197-
_, I_rbq_q4 = index_rbq.search(ds.get_queries(), k, params=rbq_params)
198-
199186
index_flat = faiss.IndexFlat(ds.d, metric_type)
200187
index_flat.train(ds.get_train())
201188
index_flat.add(ds.get_database())
202189
_, I_f = index_flat.search(ds.get_queries(), k)
203190

204-
# ensure that RaBitQ and PQ are relatively close
205-
eval_pq = faiss.eval_intersection(I_pq[:, :k], I_f[:, :k])
206-
eval_pq /= ds.nq * k
207-
eval_rbq = faiss.eval_intersection(I_rbq[:, :k], I_f[:, :k])
208-
eval_rbq /= ds.nq * k
209-
eval_rbq_q8 = faiss.eval_intersection(I_rbq_q8[:, :k], I_f[:, :k])
210-
eval_rbq_q8 /= ds.nq * k
211-
eval_rbq_q4 = faiss.eval_intersection(I_rbq_q4[:, :k], I_f[:, :k])
212-
eval_rbq_q4 /= ds.nq * k
213-
214-
print(
215-
f"PQ is {eval_pq}, "
216-
f"RaBitQ is {eval_rbq}, "
217-
f"q8 RaBitQ is {eval_rbq_q8}, "
218-
f"q4 RaBitQ is {eval_rbq_q4}"
219-
)
191+
def eval_I(I):
192+
return faiss.eval_intersection(I, I_f) / I_f.ravel().shape[0]
193+
194+
print()
195+
196+
for random_rotate in [False, True]:
197+
198+
# PQ{D/4}x4fs, also 1 bit per query dimension
199+
index_pq = faiss.IndexPQFastScan(ds.d, TEST_DIM // 4, 4, metric_type)
200+
# Share a single quantizer (much faster, minimal recall change)
201+
index_pq.pq.train_type = faiss.ProductQuantizer.Train_shared
202+
if random_rotate:
203+
# wrap with random rotations
204+
rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
205+
rrot.init(123)
220206

221-
np.testing.assert_(abs(eval_pq - eval_rbq) < 0.05)
222-
np.testing.assert_(abs(eval_pq - eval_rbq_q8) < 0.05)
223-
np.testing.assert_(abs(eval_pq - eval_rbq_q4) < 0.05)
224-
np.testing.assert_(eval_pq > 0.55)
207+
index_pq = faiss.IndexPreTransform(rrot, index_pq)
208+
index_pq.train(ds.get_train())
209+
index_pq.add(ds.get_database())
210+
211+
D_pq, I_pq = index_pq.search(ds.get_queries(), k)
212+
loss_pq = 1 - eval_I(I_pq)
213+
print(f"{random_rotate=:1}, {loss_pq=:5.3f}")
214+
np.testing.assert_(loss_pq < 0.25, f"{loss_pq}")
215+
216+
index_rbq = faiss.IndexRaBitQ(ds.d, metric_type)
217+
if random_rotate:
218+
# wrap with random rotations
219+
rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
220+
rrot.init(123)
221+
222+
index_rbq = faiss.IndexPreTransform(rrot, index_rbq)
223+
index_rbq.train(ds.get_train())
224+
index_rbq.add(ds.get_database())
225+
226+
for qb in [1, 2, 3, 4, 8]:
227+
params = faiss.RaBitQSearchParameters(qb=qb)
228+
_, I_rbq = index_rbq.search(ds.get_queries(), k, params=params)
229+
230+
# ensure that RaBitQ and PQ are relatively close
231+
loss_rbq = 1 - eval_I(I_rbq)
232+
ratio_threshold = 2 ** (1 / qb)
233+
print(
234+
f"{random_rotate=:1}, {params.qb=}: "
235+
f"{loss_rbq=:5.3f} = loss_pq * {loss_rbq/loss_pq:5.3f}"
236+
f" < {ratio_threshold=:.2f}"
237+
)
238+
239+
np.testing.assert_(loss_rbq < loss_pq * ratio_threshold)
225240

226241
def test_comparison_vs_pq_L2(self):
227242
self.do_comparison_vs_pq_test(faiss.METRIC_L2)
@@ -230,7 +245,7 @@ def test_comparison_vs_pq_IP(self):
230245
self.do_comparison_vs_pq_test(faiss.METRIC_INNER_PRODUCT)
231246

232247
def test_comparison_vs_ref_L2_rrot(self, rrot_seed=123):
233-
ds = datasets.SyntheticDataset(128, 4096, 4096, 1)
248+
ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, 1)
234249

235250
ref_rbq = ReferenceRabitQ(ds.d, Bq=8)
236251
ref_rbq.train(ds.get_train(), random_rotation(ds.d, rrot_seed))
@@ -264,7 +279,7 @@ def test_comparison_vs_ref_L2_rrot(self, rrot_seed=123):
264279
np.testing.assert_(corr > 0.9)
265280

266281
def test_comparison_vs_ref_L2(self):
267-
ds = datasets.SyntheticDataset(128, 4096, 4096, 1)
282+
ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, 1)
268283

269284
ref_rbq = ReferenceRabitQ(ds.d, Bq=8)
270285
ref_rbq.train(ds.get_train(), np.identity(ds.d))
@@ -276,18 +291,21 @@ def test_comparison_vs_ref_L2(self):
276291
index_rbq.add(ds.get_database())
277292

278293
ref_dis = ref_rbq.distances(ds.get_queries())
294+
mean_dist = ref_dis.mean()
279295

280296
dc = index_rbq.get_distance_computer()
281297
xq = ds.get_queries()
282298

283299
dc.set_query(faiss.swig_ptr(xq[0]))
284300
for j in range(ds.nb):
285301
upd_dis = dc(j)
286-
# print(f"{j} {ref_dis[0][j]} {upd_dis}")
287-
np.testing.assert_(abs(ref_dis[0][j] - upd_dis) < 0.001)
302+
np.testing.assert_(
303+
abs(ref_dis[0][j] - upd_dis) < mean_dist * 0.00001,
304+
f"{j} {ref_dis[0][j]} {upd_dis}",
305+
)
288306

289307
def do_test_serde(self, description):
290-
ds = datasets.SyntheticDataset(32, 1000, 100, 20)
308+
ds = datasets.SyntheticDataset(32, TEST_DIM, 100, 20)
291309

292310
index = faiss.index_factory(ds.d, description)
293311
index.train(ds.get_train())
@@ -308,8 +326,90 @@ def test_serde_rabitq(self):
308326

309327

310328
class TestIVFRaBitQ(unittest.TestCase):
329+
def do_comparison_vs_pq_test(self, metric_type=faiss.METRIC_L2):
330+
nlist = 64
331+
nprobe = 8
332+
nq = 1000
333+
ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, nq)
334+
k = 10
335+
336+
d = ds.d
337+
xb = ds.get_database()
338+
xt = ds.get_train()
339+
xq = ds.get_queries()
340+
341+
quantizer = faiss.IndexFlat(d, metric_type)
342+
index_flat = faiss.IndexIVFFlat(quantizer, d, nlist, metric_type)
343+
index_flat.train(xt)
344+
index_flat.add(xb)
345+
D_f, I_f = index_flat.search(
346+
xq, k, params=faiss.IVFSearchParameters(nprobe=nprobe)
347+
)
348+
349+
def eval_I(I):
350+
return faiss.eval_intersection(I, I_f) / I_f.ravel().shape[0]
351+
352+
print()
353+
354+
for random_rotate in [False, True]:
355+
quantizer = faiss.IndexFlat(d, metric_type)
356+
index_rbq = faiss.IndexIVFRaBitQ(quantizer, d, nlist, metric_type)
357+
if random_rotate:
358+
# wrap with random rotations
359+
rrot = faiss.RandomRotationMatrix(d, d)
360+
rrot.init(123)
361+
362+
index_rbq = faiss.IndexPreTransform(rrot, index_rbq)
363+
index_rbq.train(xt)
364+
index_rbq.add(xb)
365+
366+
# PQ{D/4}x4fs, also 1 bit per query dimension,
367+
# reusing quantizer from index_rbq.
368+
index_pq = faiss.IndexIVFPQFastScan(
369+
quantizer, d, nlist, TEST_DIM // 4, 4, metric_type
370+
)
371+
# Share a single quantizer (much faster, minimal recall change)
372+
index_pq.pq.train_type = faiss.ProductQuantizer.Train_shared
373+
if random_rotate:
374+
# wrap with random rotations
375+
rrot = faiss.RandomRotationMatrix(d, d)
376+
rrot.init(123)
377+
378+
index_pq = faiss.IndexPreTransform(rrot, index_pq)
379+
index_pq.train(xt)
380+
index_pq.add(xb)
381+
382+
D_pq, I_pq = index_pq.search(
383+
xq, k, params=faiss.IVFPQSearchParameters(nprobe=nprobe)
384+
)
385+
loss_pq = 1 - eval_I(I_pq)
386+
387+
print(f"{random_rotate=:1}, {loss_pq=:5.3f}")
388+
np.testing.assert_(loss_pq < 0.25, f"{loss_pq}")
389+
390+
for qb in [1, 2, 3, 4, 8]:
391+
params = faiss.IVFRaBitQSearchParameters(nprobe=nprobe, qb=qb)
392+
D_rbq, I_rbq = index_rbq.search(xq, k, params=params)
393+
394+
# ensure that RaBitQ and PQ are relatively close
395+
loss_rbq = 1 - eval_I(I_rbq)
396+
ratio_threshold = 2 ** (1 / qb)
397+
print(
398+
f"{random_rotate=:1}, {params.qb=}: "
399+
f"{loss_rbq=:5.3f} = loss_pq * {loss_rbq/loss_pq:5.3f}"
400+
f" < {ratio_threshold=:.2f}"
401+
)
402+
403+
np.testing.assert_(loss_rbq < loss_pq * ratio_threshold)
404+
405+
def test_comparison_vs_pq_L2(self):
406+
self.do_comparison_vs_pq_test(faiss.METRIC_L2)
407+
408+
def test_comparison_vs_pq_IP(self):
409+
self.do_comparison_vs_pq_test(faiss.METRIC_INNER_PRODUCT)
410+
311411
def test_comparison_vs_ref_L2(self):
312-
ds = datasets.SyntheticDataset(128, 4096, 4096, 100)
412+
ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, 100)
313413

314414
k = 10
315415
nlist = 200
@@ -318,9 +418,7 @@ def test_comparison_vs_ref_L2(self):
318418
ref_rbq.add(ds.get_database())
319419

320420
index_flat = faiss.IndexFlat(ds.d, faiss.METRIC_L2)
321-
index_rbq = faiss.IndexIVFRaBitQ(
322-
index_flat, ds.d, nlist, faiss.METRIC_L2
323-
)
421+
index_rbq = faiss.IndexIVFRaBitQ(index_flat, ds.d, nlist, faiss.METRIC_L2)
324422
index_rbq.qb = 4
325423
index_rbq.train(ds.get_train())
326424
index_rbq.add(ds.get_database())
@@ -358,9 +456,7 @@ def test_comparison_vs_ref_L2_rrot(self):
358456
ref_rbq.add(ds.get_database())
359457

360458
index_flat = faiss.IndexFlat(ds.d, faiss.METRIC_L2)
361-
index_rbq = faiss.IndexIVFRaBitQ(
362-
index_flat, ds.d, nlist, faiss.METRIC_L2
363-
)
459+
index_rbq = faiss.IndexIVFRaBitQ(index_flat, ds.d, nlist, faiss.METRIC_L2)
364460
index_rbq.qb = 4
365461

366462
# wrap with random rotations

0 commit comments

Comments
 (0)