1717package com .nvidia .cuvs ;
1818
1919import java .util .Arrays ;
20+ import java .util .BitSet ;
2021import java .util .List ;
2122
2223/**
@@ -28,7 +29,8 @@ public class BruteForceQuery {
2829
2930 private List <Integer > mapping ;
3031 private float [][] queryVectors ;
31- private long [] prefilter ;
32+ private BitSet [] prefilters ;
33+ private int numDocs = -1 ;
3234 private int topK ;
3335
3436 /**
@@ -40,12 +42,15 @@ public class BruteForceQuery {
4042 * @param topK the top k results to return
4143 * @param prefilter the prefilter data to use while searching the BRUTEFORCE
4244 * index
45+ * @param numDocs Maximum of bits in each prefilter, representing number of documents in this index.
46+ * Used only when prefilter(s) is/are passed.
4347 */
44- public BruteForceQuery (float [][] queryVectors , List <Integer > mapping , int topK , long [] prefilter ) {
48+ public BruteForceQuery (float [][] queryVectors , List <Integer > mapping , int topK , BitSet [] prefilters , int numDocs ) {
4549 this .queryVectors = queryVectors ;
4650 this .mapping = mapping ;
4751 this .topK = topK ;
48- this .prefilter = prefilter ;
52+ this .prefilters = prefilters ;
53+ this .numDocs = numDocs ;
4954 }
5055
5156 /**
@@ -78,16 +83,25 @@ public int getTopK() {
7883 /**
7984 * Gets the prefilter long array
8085 *
81- * @return a long array
86+ * @return an array of bitsets
8287 */
83- public long [] getPrefilter () {
84- return prefilter ;
88+ public BitSet [] getPrefilters () {
89+ return prefilters ;
90+ }
91+
92+ /**
93+ * Gets the number of documents supposed to be in this index, as used for prefilters
94+ *
95+ * @return number of documents as an integer
96+ */
97+ public int getNumDocs () {
98+ return numDocs ;
8599 }
86100
87101 @ Override
88102 public String toString () {
89103 return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays .toString (queryVectors ) + ", prefilter="
90- + Arrays .toString (prefilter ) + ", topK=" + topK + "]" ;
104+ + Arrays .toString (prefilters ) + ", topK=" + topK + "]" ;
91105 }
92106
93107 /**
@@ -96,7 +110,8 @@ public String toString() {
96110 public static class Builder {
97111
98112 private float [][] queryVectors ;
99- private long [] prefilter ;
113+ private BitSet [] prefilters ;
114+ private int numDocs ;
100115 private List <Integer > mapping ;
101116 private int topK = 2 ;
102117
@@ -134,13 +149,15 @@ public Builder withTopK(int topK) {
134149 }
135150
136151 /**
137- * Sets the prefilter data for building the {@link BruteForceQuery}.
152+ * Sets the prefilters data for building the {@link BruteForceQuery}.
138153 *
139- * @param prefilter a one-dimensional long array
154+ * @param prefilters array of bitsets, as many as queries, each containing as
155+ * many bits as there are vectors in the index
140156 * @return an instance of this Builder
141157 */
142- public Builder withPrefilter (long [] prefilter ) {
143- this .prefilter = prefilter ;
158+ public Builder withPrefilter (BitSet [] prefilters , int numDocs ) {
159+ this .prefilters = prefilters ;
160+ this .numDocs = numDocs ;
144161 return this ;
145162 }
146163
@@ -150,7 +167,7 @@ public Builder withPrefilter(long[] prefilter) {
150167 * @return an instance of {@link BruteForceQuery}
151168 */
152169 public BruteForceQuery build () {
153- return new BruteForceQuery (queryVectors , mapping , topK , prefilter );
170+ return new BruteForceQuery (queryVectors , mapping , topK , prefilters , numDocs );
154171 }
155172 }
156173}
0 commit comments