@@ -17,6 +17,11 @@ def random_rotation(d, seed=123):
17
17
return Q
18
18
19
19
20
# Dimension chosen to exercise the SIMD code paths; keep it a multiple of 16.
TEST_DIM = 512 * 2 + 256 + 128 + 64 + 16  # 1488
# Size passed as the second and third SyntheticDataset arguments in the tests.
TEST_N = 4096
23
+
24
+
20
25
# based on https://gist.github.com/mdouze/0b2386c31d7fb8b20ae04f3fcbbf4d9d
21
26
class ReferenceRabitQ :
22
27
"""Exact translation of the paper
@@ -175,53 +180,63 @@ def search(self, x, k):
175
180
176
181
class TestRaBitQ (unittest .TestCase ):
177
182
def do_comparison_vs_pq_test(self, metric_type=faiss.METRIC_L2):
    """Check that RaBitQ results stay close to a PQ fast-scan baseline.

    Results of an ``IndexFlat`` search are the reference. "Loss" below is
    ``1 - eval_intersection(I, I_ref) / I_ref.size``. With and without a
    random-rotation pre-transform, the test asserts that the PQ loss is
    below 0.25 and that, for every ``qb`` in ``[1, 2, 3, 4, 8]``, the
    RaBitQ loss is below ``loss_pq * 2 ** (1 / qb)``.

    Args:
        metric_type: faiss metric passed to every index built here.
    """
    ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, 100)
    k = 10

    xt = ds.get_train()
    xb = ds.get_database()
    xq = ds.get_queries()

    # Reference results that every other index is compared against.
    index_flat = faiss.IndexFlat(ds.d, metric_type)
    index_flat.train(xt)
    index_flat.add(xb)
    _, I_f = index_flat.search(xq, k)

    def eval_I(I):
        # Intersection with the reference, normalized by its entry count.
        return faiss.eval_intersection(I, I_f) / I_f.ravel().shape[0]

    print()

    for random_rotate in [False, True]:

        # PQ{D/4}x4fs, also 1 bit per query dimension
        index_pq = faiss.IndexPQFastScan(ds.d, TEST_DIM // 4, 4, metric_type)
        # Share a single quantizer (much faster, minimal recall change)
        index_pq.pq.train_type = faiss.ProductQuantizer.Train_shared
        if random_rotate:
            # wrap with random rotations
            rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
            rrot.init(123)

            index_pq = faiss.IndexPreTransform(rrot, index_pq)
        index_pq.train(xt)
        index_pq.add(xb)

        _, I_pq = index_pq.search(xq, k)
        loss_pq = 1 - eval_I(I_pq)
        print(f"{random_rotate=:1}, {loss_pq=:5.3f}")
        np.testing.assert_(loss_pq < 0.25, f"{loss_pq}")

        index_rbq = faiss.IndexRaBitQ(ds.d, metric_type)
        if random_rotate:
            # wrap with random rotations (same seed as the PQ wrapper)
            rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
            rrot.init(123)

            index_rbq = faiss.IndexPreTransform(rrot, index_rbq)
        index_rbq.train(xt)
        index_rbq.add(xb)

        for qb in [1, 2, 3, 4, 8]:
            params = faiss.RaBitQSearchParameters(qb=qb)
            _, I_rbq = index_rbq.search(xq, k, params=params)

            # ensure that the RaBitQ loss is within a qb-dependent factor
            # of the PQ loss
            loss_rbq = 1 - eval_I(I_rbq)
            ratio_threshold = 2 ** (1 / qb)
            # Only used for the log line; avoid dividing by a zero PQ loss.
            ratio = loss_rbq / loss_pq if loss_pq > 0 else float("nan")
            print(
                f"{random_rotate=:1}, {params.qb=}: "
                f"{loss_rbq=:5.3f} = loss_pq * {ratio:5.3f}"
                f" < {ratio_threshold=:.2f}"
            )

            np.testing.assert_(
                loss_rbq < loss_pq * ratio_threshold,
                f"{random_rotate=}, {qb=}: {loss_rbq} vs {loss_pq}",
            )
225
240
226
241
def test_comparison_vs_pq_L2(self):
    """Run the RaBitQ-vs-PQ comparison with the L2 metric."""
    self.do_comparison_vs_pq_test(faiss.METRIC_L2)
@@ -230,7 +245,7 @@ def test_comparison_vs_pq_IP(self):
230
245
self .do_comparison_vs_pq_test (faiss .METRIC_INNER_PRODUCT )
231
246
232
247
def test_comparison_vs_ref_L2_rrot (self , rrot_seed = 123 ):
233
- ds = datasets .SyntheticDataset (128 , 4096 , 4096 , 1 )
248
+ ds = datasets .SyntheticDataset (TEST_DIM , TEST_N , TEST_N , 1 )
234
249
235
250
ref_rbq = ReferenceRabitQ (ds .d , Bq = 8 )
236
251
ref_rbq .train (ds .get_train (), random_rotation (ds .d , rrot_seed ))
@@ -264,7 +279,7 @@ def test_comparison_vs_ref_L2_rrot(self, rrot_seed=123):
264
279
np .testing .assert_ (corr > 0.9 )
265
280
266
281
def test_comparison_vs_ref_L2 (self ):
267
- ds = datasets .SyntheticDataset (128 , 4096 , 4096 , 1 )
282
+ ds = datasets .SyntheticDataset (TEST_DIM , TEST_N , TEST_N , 1 )
268
283
269
284
ref_rbq = ReferenceRabitQ (ds .d , Bq = 8 )
270
285
ref_rbq .train (ds .get_train (), np .identity (ds .d ))
@@ -276,18 +291,21 @@ def test_comparison_vs_ref_L2(self):
276
291
index_rbq .add (ds .get_database ())
277
292
278
293
ref_dis = ref_rbq .distances (ds .get_queries ())
294
+ mean_dist = ref_dis .mean ()
279
295
280
296
dc = index_rbq .get_distance_computer ()
281
297
xq = ds .get_queries ()
282
298
283
299
dc .set_query (faiss .swig_ptr (xq [0 ]))
284
300
for j in range (ds .nb ):
285
301
upd_dis = dc (j )
286
- # print(f"{j} {ref_dis[0][j]} {upd_dis}")
287
- np .testing .assert_ (abs (ref_dis [0 ][j ] - upd_dis ) < 0.001 )
302
+ np .testing .assert_ (
303
+ abs (ref_dis [0 ][j ] - upd_dis ) < mean_dist * 0.00001 ,
304
+ f"{ j } { ref_dis [0 ][j ]} { upd_dis } " ,
305
+ )
288
306
289
307
def do_test_serde (self , description ):
290
- ds = datasets .SyntheticDataset (32 , 1000 , 100 , 20 )
308
+ ds = datasets .SyntheticDataset (32 , TEST_DIM , 100 , 20 )
291
309
292
310
index = faiss .index_factory (ds .d , description )
293
311
index .train (ds .get_train ())
@@ -308,8 +326,90 @@ def test_serde_rabitq(self):
308
326
309
327
310
328
class TestIVFRaBitQ (unittest .TestCase ):
329
def do_comparison_vs_pq_test(self, metric_type=faiss.METRIC_L2):
    """Check that IVF-RaBitQ results stay close to an IVF-PQ fast-scan baseline.

    Results of an ``IndexIVFFlat`` search (same ``nlist`` / ``nprobe``)
    are the reference. "Loss" below is
    ``1 - eval_intersection(I, I_ref) / I_ref.size``. With and without a
    random-rotation pre-transform, the test asserts that the PQ loss is
    below 0.25 and that, for every ``qb`` in ``[1, 2, 3, 4, 8]``, the
    RaBitQ loss is below ``loss_pq * 2 ** (1 / qb)``.

    Args:
        metric_type: faiss metric passed to every index built here.
    """
    nlist = 64
    nprobe = 8
    nq = 1000
    ds = datasets.SyntheticDataset(TEST_DIM, TEST_N, TEST_N, nq)
    k = 10

    d = ds.d
    xb = ds.get_database()
    xt = ds.get_train()
    xq = ds.get_queries()

    # Reference results that every other index is compared against.
    quantizer = faiss.IndexFlat(d, metric_type)
    index_flat = faiss.IndexIVFFlat(quantizer, d, nlist, metric_type)
    index_flat.train(xt)
    index_flat.add(xb)
    _, I_f = index_flat.search(
        xq, k, params=faiss.IVFSearchParameters(nprobe=nprobe)
    )

    def eval_I(I):
        # Intersection with the reference, normalized by its entry count.
        return faiss.eval_intersection(I, I_f) / I_f.ravel().shape[0]

    print()

    for random_rotate in [False, True]:
        # A fresh coarse quantizer per iteration, shared by both indexes.
        quantizer = faiss.IndexFlat(d, metric_type)
        index_rbq = faiss.IndexIVFRaBitQ(quantizer, d, nlist, metric_type)
        if random_rotate:
            # wrap with random rotations
            rrot = faiss.RandomRotationMatrix(d, d)
            rrot.init(123)

            index_rbq = faiss.IndexPreTransform(rrot, index_rbq)
        index_rbq.train(xt)
        index_rbq.add(xb)

        # PQ{D/4}x4fs, also 1 bit per query dimension,
        # reusing quantizer from index_rbq.
        index_pq = faiss.IndexIVFPQFastScan(
            quantizer, d, nlist, TEST_DIM // 4, 4, metric_type
        )
        # Share a single quantizer (much faster, minimal recall change)
        index_pq.pq.train_type = faiss.ProductQuantizer.Train_shared
        if random_rotate:
            # wrap with random rotations (same seed as the RaBitQ wrapper)
            rrot = faiss.RandomRotationMatrix(d, d)
            rrot.init(123)

            index_pq = faiss.IndexPreTransform(rrot, index_pq)
        index_pq.train(xt)
        index_pq.add(xb)

        _, I_pq = index_pq.search(
            xq, k, params=faiss.IVFPQSearchParameters(nprobe=nprobe)
        )
        loss_pq = 1 - eval_I(I_pq)

        print(f"{random_rotate=:1}, {loss_pq=:5.3f}")
        np.testing.assert_(loss_pq < 0.25, f"{loss_pq}")

        for qb in [1, 2, 3, 4, 8]:
            params = faiss.IVFRaBitQSearchParameters(nprobe=nprobe, qb=qb)
            _, I_rbq = index_rbq.search(xq, k, params=params)

            # ensure that the RaBitQ loss is within a qb-dependent factor
            # of the PQ loss
            loss_rbq = 1 - eval_I(I_rbq)
            ratio_threshold = 2 ** (1 / qb)
            # Only used for the log line; avoid dividing by a zero PQ loss.
            ratio = loss_rbq / loss_pq if loss_pq > 0 else float("nan")
            print(
                f"{random_rotate=:1}, {params.qb=}: "
                f"{loss_rbq=:5.3f} = loss_pq * {ratio:5.3f}"
                f" < {ratio_threshold=:.2f}"
            )

            np.testing.assert_(
                loss_rbq < loss_pq * ratio_threshold,
                f"{random_rotate=}, {qb=}: {loss_rbq} vs {loss_pq}",
            )
404
+
405
def test_comparison_vs_pq_L2(self):
    """Run the IVF RaBitQ-vs-PQ comparison with the L2 metric."""
    self.do_comparison_vs_pq_test(faiss.METRIC_L2)
407
+
408
def test_comparison_vs_pq_IP(self):
    """Run the IVF RaBitQ-vs-PQ comparison with the inner-product metric."""
    self.do_comparison_vs_pq_test(faiss.METRIC_INNER_PRODUCT)
410
+
311
411
def test_comparison_vs_ref_L2 (self ):
312
- ds = datasets .SyntheticDataset (128 , 4096 , 4096 , 100 )
412
+ ds = datasets .SyntheticDataset (TEST_DIM , TEST_N , TEST_N , 100 )
313
413
314
414
k = 10
315
415
nlist = 200
@@ -318,9 +418,7 @@ def test_comparison_vs_ref_L2(self):
318
418
ref_rbq .add (ds .get_database ())
319
419
320
420
index_flat = faiss .IndexFlat (ds .d , faiss .METRIC_L2 )
321
- index_rbq = faiss .IndexIVFRaBitQ (
322
- index_flat , ds .d , nlist , faiss .METRIC_L2
323
- )
421
+ index_rbq = faiss .IndexIVFRaBitQ (index_flat , ds .d , nlist , faiss .METRIC_L2 )
324
422
index_rbq .qb = 4
325
423
index_rbq .train (ds .get_train ())
326
424
index_rbq .add (ds .get_database ())
@@ -358,9 +456,7 @@ def test_comparison_vs_ref_L2_rrot(self):
358
456
ref_rbq .add (ds .get_database ())
359
457
360
458
index_flat = faiss .IndexFlat (ds .d , faiss .METRIC_L2 )
361
- index_rbq = faiss .IndexIVFRaBitQ (
362
- index_flat , ds .d , nlist , faiss .METRIC_L2
363
- )
459
+ index_rbq = faiss .IndexIVFRaBitQ (index_flat , ds .d , nlist , faiss .METRIC_L2 )
364
460
index_rbq .qb = 4
365
461
366
462
# wrap with random rotations
0 commit comments