Skip to content

Commit cc73718

Browse files
committed
Allow custom FAISS index types
1 parent 23d38cc commit cc73718

File tree

2 files changed: 14 additions and 4 deletions

2 files changed: 14 additions and 4 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "DataDreamer"
3-
version = "0.41.0"
3+
version = "0.42.0"
44
description = "Prompt. Generate Synthetic Data. Train & Align Models."
55
license = "MIT"
66
authors= [

src/retrievers/embedding_retriever.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,10 @@
77

88
import numpy as np
99
import torch
10-
from datasets.fingerprint import Hasher
1110
from sqlitedict import SqliteDict
1211

12+
from datasets.fingerprint import Hasher
13+
1314
from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn
1415
from ..embedders.embedder import Embedder
1516
from ..utils.background_utils import dill_serializer, proxy_resource_in_background
@@ -28,6 +29,7 @@ def __init__(
2829
texts: OutputDatasetColumn | OutputIterableDatasetColumn,
2930
embedder: Embedder,
3031
truncate: bool = False,
32+
index_type: str | None = None,
3133
index_batch_size: int = DEFAULT_BATCH_SIZE,
3234
index_instruction: None | str = None,
3335
query_instruction: None | str = None,
@@ -52,6 +54,7 @@ def __init__(
5254
super().__init__(texts=texts, cache_folder_path=cache_folder_path)
5355
self.embedder = embedder
5456
self.truncate = truncate
57+
self.index_type = index_type
5558
self.index_batch_size = index_batch_size
5659
self.index_instruction = index_instruction
5760
self.query_instruction = query_instruction
@@ -99,7 +102,12 @@ def __init__(self_resource) -> None:
99102
):
100103
self._initialize_retriever_index_folder()
101104
index_logger.info("Building index.")
102-
index = faiss.IndexFlatIP(self.embedder.dims)
105+
if self.index_type is not None: # pragma: no cover
106+
index = faiss.index_factory(
107+
self.embedder.dims, self.index_type
108+
)
109+
else:
110+
index = faiss.IndexFlatIP(self.embedder.dims)
103111
if self.device is not None: # pragma: no cover
104112
index = faiss.index_cpu_to_gpus_list(
105113
index=index, gpus=self.device
@@ -137,6 +145,8 @@ def __init__(self_resource) -> None:
137145
),
138146
)
139147
)
148+
if not index.is_trained: # pragma: no cover
149+
index.train(texts_embedded)
140150
index.add(texts_embedded)
141151
for id_, text in zip(ids_batch, texts_batch):
142152
index_lookup[id_] = text
@@ -212,7 +222,7 @@ def _batch_lookup(self, indices: list[int]) -> dict[int, dict[str, Any]]:
212222
break
213223
lookup_query = (
214224
"SELECT key, value FROM lookup"
215-
f' WHERE key IN ({",".join(["?"] * len(indices_batch))})'
225+
f" WHERE key IN ({','.join(['?'] * len(indices_batch))})"
216226
)
217227
results.update(
218228
{

0 commit comments

Comments
 (0)