
Commit 42f4e6f

Added Milvus support through fastembed and QoL console logs for the rag command
Signed-off-by: Brian <[email protected]>
1 parent 3f8e31a commit 42f4e6f

10 files changed: +197 −120 lines

container-images/scripts/build_rag.sh

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ docling() {
 }

 rag() {
-    ${python} -m pip install --prefix=/usr wheel qdrant_client fastembed openai fastapi uvicorn
+    ${python} -m pip install --prefix=/usr wheel qdrant_client pymilvus fastembed openai fastapi uvicorn
     rag_framework load
 }
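
The build change above only adds the pymilvus wheel; the rest of rag() is untouched. As a quick sanity sketch (not part of the commit; the file path, collection name, and vector values are made up), the embedded Milvus Lite engine that pymilvus bundles can be exercised with nothing but a local .db path:

# Hypothetical smoke test for the newly installed pymilvus dependency:
# creates an embedded Milvus Lite database file and round-trips one record.
from pymilvus import MilvusClient

client = MilvusClient(uri="/tmp/pymilvus_smoke.db")             # local Milvus Lite file (made-up path)
client.create_collection(collection_name="smoke", dimension=4)  # quick-setup schema: "id" + "vector"
client.insert(collection_name="smoke", data=[{"id": 0, "vector": [0.1, 0.2, 0.3, 0.4]}])
print(client.query(collection_name="smoke", ids=[0], output_fields=["id"]))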

container-images/scripts/doc2rag

Lines changed: 108 additions & 33 deletions
@@ -1,25 +1,36 @@
 #!/usr/bin/env python3

+# suppress pkg warning for pymilvus
 import argparse
 import errno
 import hashlib
+import itertools
 import os
 import os.path
 import sys
+import threading
+import time
 import uuid
+import warnings
+from pathlib import Path

 import docling
 import qdrant_client
 from docling.chunking import HybridChunker
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import PdfPipelineOptions
 from docling.document_converter import DocumentConverter, PdfFormatOption
+from fastembed import SparseTextEmbedding, TextEmbedding
+from pymilvus import DataType, MilvusClient
 from qdrant_client import models

+warnings.filterwarnings("ignore", category=UserWarning)
+
 # Global Vars
 EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
 SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
 COLLECTION_NAME = "rag"
+os.environ["TOKENIZERS_PARALLELISM"] = "true"


 class Converter:
@@ -37,58 +48,122 @@ class Converter:
         self.doc_converter = DocumentConverter(
             format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
         )
-        if self.format == "qdrant":
-            self.client = qdrant_client.QdrantClient(path=self.output)
-            self.client.set_model(EMBED_MODEL)
-            self.client.set_sparse_model(SPARSE_MODEL)
-            # optimizations to reduce ram
-            self.client.create_collection(
-                collection_name=COLLECTION_NAME,
-                vectors_config=self.client.get_fastembed_vector_params(on_disk=True),
-                sparse_vectors_config=self.client.get_fastembed_sparse_vector_params(on_disk=True),
-                quantization_config=models.ScalarQuantization(
-                    scalar=models.ScalarQuantizationConfig(
-                        type=models.ScalarType.INT8,
-                        always_ram=True,
-                    ),
-                ),
-            )

     def add(self, file_path):
         if os.path.isdir(file_path):
             self.walk(file_path)  # Walk directory and process all files
         else:
             self.sources.append(file_path)  # Process the single file

-    def convert_qdrant(self, results):
-        documents, ids = [], []
+    def chunk(self, docs):
         chunker = HybridChunker(tokenizer=EMBED_MODEL, overlap=100, merge_peers=True)
-        for result in results:
-            chunk_iter = chunker.chunk(dl_doc=result.document)
+        documents, ids = [], []
+
+        for file in docs:
+            chunk_iter = chunker.chunk(dl_doc=file.document)
             for chunk in chunk_iter:
                 # Extract the enriched text from the chunk
                 doc_text = chunker.contextualize(chunk=chunk)
-                # Append to respective lists
                 documents.append(doc_text)
-                # Generate unique ID for the chunk
                 doc_id = self.generate_hash(doc_text)
                 ids.append(doc_id)
-        return self.client.add(COLLECTION_NAME, documents=documents, ids=ids, batch_size=1)
+        return documents, ids
+
+    def convert_milvus(self, docs):
+        output_dir = Path(self.output)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        milvus_client = MilvusClient(uri=os.path.join(self.output, "milvus.db"))
+        collection_name = COLLECTION_NAME
+        dmodel = TextEmbedding(model_name=EMBED_MODEL)
+        smodel = SparseTextEmbedding(model_name=SPARSE_MODEL)
+        test_embedding = next(dmodel.embed("This is a test"))
+        embedding_dim = len(test_embedding)
+        schema = MilvusClient.create_schema(
+            auto_id=False,
+            enable_dynamic_field=True,
+        )
+        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
+        schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=1000)
+        schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
+        schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=embedding_dim)
+        index_params = milvus_client.prepare_index_params()
+        index_params.add_index(field_name="dense", index_name="dense_index", index_type="AUTOINDEX")
+        index_params.add_index(field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX")
+        milvus_client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)
+        # Chunk and add chunks to collection 1 by 1
+        chunks, ids = self.chunk(docs)
+        # Batch-embed chunks for better performance
+        dense_embeddings = list(dmodel.embed(chunks))
+        sparse_embeddings_list = list(smodel.embed(chunks))
+
+        for i, (chunk, id) in enumerate(zip(chunks, ids)):
+            sparse_vector = sparse_embeddings_list[i].as_dict()
+            milvus_client.insert(
+                collection_name=collection_name,
+                data=[{"id": id, "text": chunk, "sparse": sparse_vector, "dense": dense_embeddings[i]}],
+            )
+            # Flush every 100 records to reduce RAM usage
+            if i % 100 == 0:
+                milvus_client.flush(collection_name=collection_name)
+            print(f"\rProcessed chunk {i+1}/{len(chunks)}", end='', flush=True)
+        milvus_client.flush(collection_name=collection_name)
+        print("\n")
+        return
+
+    def convert_qdrant(self, results):
+        qclient = qdrant_client.QdrantClient(path=self.output)
+        qclient.set_model(EMBED_MODEL)
+        qclient.set_sparse_model(SPARSE_MODEL)
+        # optimizations to reduce ram
+        qclient.create_collection(
+            collection_name=COLLECTION_NAME,
+            vectors_config=qclient.get_fastembed_vector_params(on_disk=True),
+            sparse_vectors_config=qclient.get_fastembed_sparse_vector_params(on_disk=True),
+            quantization_config=models.ScalarQuantization(
+                scalar=models.ScalarQuantizationConfig(
+                    type=models.ScalarType.INT8,
+                    always_ram=True,
+                ),
+            ),
+        )
+        chunks, ids = self.chunk(results)
+        return qclient.add(COLLECTION_NAME, documents=chunks, ids=ids, batch_size=1)
+
+    def show_progress(self, message, stop_event):
+        spinner = itertools.cycle([".", "..", "..."])
+        while not stop_event.is_set():
+            sys.stdout.write(f"\r{message} {next(spinner)} ")
+            sys.stdout.flush()
+            time.sleep(0.5)
+        sys.stdout.write("\r" + " " * 50 + "\r")

     def convert(self):
-        results = self.doc_converter.convert_all(self.sources)
+        results = []
+        names = []
+        for source in self.sources:
+            name = Path((str(source))).stem
+            names.append(name)
+            stop_event = threading.Event()
+            progress_thread = threading.Thread(target=self.show_progress, args=(f"Converting {name}.pdf", stop_event))
+            progress_thread.start()
+            try:
+                results.append(self.doc_converter.convert(source))
+            finally:
+                stop_event.set()
+                progress_thread.join()
+            print(f"Finished converting {name}.pdf")
+
         if self.format == "qdrant":
             return self.convert_qdrant(results)
         if self.format == "markdown":
-            # Export the converted document to Markdown
             return self.convert_markdown(results)
         if self.format == "json":
-            # Export the converted document to JSON
             return self.convert_json(results)
+        if self.format == "milvus":
+            self.convert_milvus(results)

     def convert_markdown(self, results):
         ctr = 0
-        # Process the conversion results
         for ctr, result in enumerate(results):
             dirname = self.output + os.path.dirname(self.sources[ctr])
             os.makedirs(dirname, exist_ok=True)
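
For orientation, a sketch (not in the commit) of reading back what convert_milvus writes: the collection name "rag", the "dense" and "text" fields, and the default EMBED_MODEL come from the code above, while the output directory and query string are assumptions, as is a pymilvus version whose MilvusClient.search accepts anns_field (needed here because the schema defines two vector fields).

# Hedged sketch: dense retrieval from the milvus.db produced by convert_milvus.
from fastembed import TextEmbedding
from pymilvus import MilvusClient

EMBED_MODEL = "jinaai/jina-embeddings-v2-small-en"   # default from the script above

client = MilvusClient(uri="rag-output/milvus.db")    # hypothetical output directory
query = next(TextEmbedding(model_name=EMBED_MODEL).embed(["How do I build the RAG image?"]))

hits = client.search(
    collection_name="rag",            # COLLECTION_NAME in the script
    data=[query.tolist()],
    anns_field="dense",               # dense vector field defined in the schema above
    limit=3,
    output_fields=["text"],
)
for hit in hits[0]:
    print(hit["distance"], hit["entity"]["text"])
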
@@ -111,12 +186,12 @@ class Converter:
         if os.path.isfile(file):
             self.sources.append(file)

-    def generate_hash(self, document: str) -> str:
-        """Generate a unique hash for a document."""
+    def generate_hash(self, document: str) -> int:
+        """Generate a unique int64 hash from the document text."""
         sha256_hash = hashlib.sha256(document.encode('utf-8')).hexdigest()
-
-        # Use the first 32 characters of the hash to create a UUID
-        return str(uuid.UUID(sha256_hash[:32]))
+        uuid_val = uuid.UUID(sha256_hash[:32])
+        # Convert to signed int64 (Milvus requires signed 64-bit)
+        return uuid_val.int & ((1 << 63) - 1)


 def load():
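
A quick check (not in the commit) of the new id scheme: generate_hash now masks the UUID's 128-bit integer down to its low 63 bits, so the result always fits the signed INT64 primary key declared in the Milvus schema. The input text below is made up.

# Reproduces the generate_hash logic from the hunk above on a made-up chunk.
import hashlib
import uuid

text = "example chunk"                                     # hypothetical input
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
doc_id = uuid.UUID(digest[:32]).int & ((1 << 63) - 1)      # keep the low 63 bits
assert 0 <= doc_id < 2 ** 63                               # always fits a signed 64-bit field
print(doc_id)
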
@@ -139,7 +214,7 @@ parser.add_argument(
     "--format",
     default="qdrant",
     help="Output format for RAG Data",
-    choices=["qdrant", "json", "markdown"],
+    choices=["qdrant", "json", "markdown", "milvus"],
 )
 parser.add_argument(
     "--ocr",
