77
88import numpy as np
99import torch
10- from datasets .fingerprint import Hasher
1110from sqlitedict import SqliteDict
1211
12+ from datasets .fingerprint import Hasher
13+
1314from ..datasets import OutputDatasetColumn , OutputIterableDatasetColumn
1415from ..embedders .embedder import Embedder
1516from ..utils .background_utils import dill_serializer , proxy_resource_in_background
@@ -28,6 +29,7 @@ def __init__(
2829 texts : OutputDatasetColumn | OutputIterableDatasetColumn ,
2930 embedder : Embedder ,
3031 truncate : bool = False ,
32+ index_type : str | None = None ,
3133 index_batch_size : int = DEFAULT_BATCH_SIZE ,
3234 index_instruction : None | str = None ,
3335 query_instruction : None | str = None ,
@@ -52,6 +54,7 @@ def __init__(
5254 super ().__init__ (texts = texts , cache_folder_path = cache_folder_path )
5355 self .embedder = embedder
5456 self .truncate = truncate
57+ self .index_type = index_type
5558 self .index_batch_size = index_batch_size
5659 self .index_instruction = index_instruction
5760 self .query_instruction = query_instruction
@@ -99,7 +102,12 @@ def __init__(self_resource) -> None:
99102 ):
100103 self ._initialize_retriever_index_folder ()
101104 index_logger .info ("Building index." )
102- index = faiss .IndexFlatIP (self .embedder .dims )
105+ if self .index_type is not None : # pragma: no cover
106+ index = faiss .index_factory (
107+ self .embedder .dims , self .index_type
108+ )
109+ else :
110+ index = faiss .IndexFlatIP (self .embedder .dims )
103111 if self .device is not None : # pragma: no cover
104112 index = faiss .index_cpu_to_gpus_list (
105113 index = index , gpus = self .device
@@ -137,6 +145,8 @@ def __init__(self_resource) -> None:
137145 ),
138146 )
139147 )
148+ if not index .is_trained : # pragma: no cover
149+ index .train (texts_embedded )
140150 index .add (texts_embedded )
141151 for id_ , text in zip (ids_batch , texts_batch ):
142152 index_lookup [id_ ] = text
@@ -212,7 +222,7 @@ def _batch_lookup(self, indices: list[int]) -> dict[int, dict[str, Any]]:
212222 break
213223 lookup_query = (
214224 "SELECT key, value FROM lookup"
215- f' WHERE key IN ({ "," .join (["?" ] * len (indices_batch ))} )'
225+ f" WHERE key IN ({ ',' .join (['?' ] * len (indices_batch ))} )"
216226 )
217227 results .update (
218228 {
0 commit comments