-
Notifications
You must be signed in to change notification settings - Fork 250
added milvus support and qol console logs for rag command #1720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,36 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# suppress pkg warning for pymilvus | ||
import argparse | ||
import errno | ||
import hashlib | ||
import itertools | ||
import os | ||
import os.path | ||
import sys | ||
import threading | ||
import time | ||
import uuid | ||
import warnings | ||
from pathlib import Path | ||
|
||
import docling | ||
import qdrant_client | ||
from docling.chunking import HybridChunker | ||
from docling.datamodel.base_models import InputFormat | ||
from docling.datamodel.pipeline_options import PdfPipelineOptions | ||
from docling.document_converter import DocumentConverter, PdfFormatOption | ||
from fastembed import SparseTextEmbedding, TextEmbedding | ||
from pymilvus import DataType, MilvusClient | ||
from qdrant_client import models | ||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
# Global Vars | ||
EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en") | ||
SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1") | ||
COLLECTION_NAME = "rag" | ||
os.environ["TOKENIZERS_PARALLELISM"] = "true" | ||
|
||
|
||
class Converter: | ||
|
@@ -37,58 +48,122 @@ class Converter: | |
self.doc_converter = DocumentConverter( | ||
format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)} | ||
) | ||
if self.format == "qdrant": | ||
self.client = qdrant_client.QdrantClient(path=self.output) | ||
self.client.set_model(EMBED_MODEL) | ||
self.client.set_sparse_model(SPARSE_MODEL) | ||
# optimizations to reduce ram | ||
self.client.create_collection( | ||
collection_name=COLLECTION_NAME, | ||
vectors_config=self.client.get_fastembed_vector_params(on_disk=True), | ||
sparse_vectors_config=self.client.get_fastembed_sparse_vector_params(on_disk=True), | ||
quantization_config=models.ScalarQuantization( | ||
scalar=models.ScalarQuantizationConfig( | ||
type=models.ScalarType.INT8, | ||
always_ram=True, | ||
), | ||
), | ||
) | ||
|
||
def add(self, file_path):
    """Queue ``file_path`` for conversion.

    A directory is handed to ``self.walk``; anything else is appended to
    ``self.sources`` unchanged.
    """
    if not os.path.isdir(file_path):
        # Single file (or any non-directory path): queue it as-is.
        self.sources.append(file_path)
        return
    # Directory: let walk() collect the files it contains.
    self.walk(file_path)
|
||
def convert_qdrant(self, results): | ||
documents, ids = [], [] | ||
def chunk(self, docs):
    """Split converted documents into contextualized text chunks.

    Args:
        docs: iterable of conversion results; each item's ``.document`` is
            passed to the chunker.

    Returns:
        A ``(documents, ids)`` tuple of parallel lists: the contextualized
        text of every chunk and the id ``self.generate_hash`` produced for
        that text.
    """
    chunker = HybridChunker(tokenizer=EMBED_MODEL, overlap=100, merge_peers=True)
    documents, ids = [], []

    for converted in docs:
        for piece in chunker.chunk(dl_doc=converted.document):
            # Extract the enriched text from the chunk
            doc_text = chunker.contextualize(chunk=piece)
            documents.append(doc_text)
            # The id is derived from the chunk text itself
            ids.append(self.generate_hash(doc_text))
    return documents, ids
|
||
def convert_milvus(self, docs):
    """Chunk the converted documents and store them in a local Milvus database.

    Creates the output directory if needed, opens ``<output>/milvus.db``,
    creates the ``COLLECTION_NAME`` collection (int64 primary key, chunk text,
    sparse vector, dense vector), embeds every chunk with the dense and sparse
    models and inserts the records one at a time.

    Args:
        docs: conversion results, forwarded to ``self.chunk``.
    """
    output_dir = Path(self.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    milvus_client = MilvusClient(uri=os.path.join(self.output, "milvus.db"))
    dmodel = TextEmbedding(model_name=EMBED_MODEL)
    smodel = SparseTextEmbedding(model_name=SPARSE_MODEL)
    # Embed a probe string to learn the dense dimension.
    # NOTE(review): a reviewer suggested reading a `.dim` property instead; the
    # author reported "'TextEmbedding' object has no attribute 'dim'".
    embedding_dim = len(next(dmodel.embed("This is a test")))

    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    # NOTE(review): max_length=1000 was questioned in review; the author keeps it
    # small on purpose to limit RAM. Presumably longer chunk texts are rejected
    # by Milvus - TODO confirm.
    schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=1000)
    schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
    schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=embedding_dim)

    index_params = milvus_client.prepare_index_params()
    index_params.add_index(field_name="dense", index_name="dense_index", index_type="AUTOINDEX")
    index_params.add_index(field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX")
    milvus_client.create_collection(collection_name=COLLECTION_NAME, schema=schema, index_params=index_params)

    chunks, ids = self.chunk(docs)
    # Batch-embed all chunks up front (review feedback: much faster than
    # embedding inside the loop, at the cost of holding every embedding in
    # memory at once).
    dense_embeddings = list(dmodel.embed(chunks))
    sparse_embeddings = list(smodel.embed(chunks))

    total = len(chunks)
    records = zip(chunks, ids, dense_embeddings, sparse_embeddings)
    # Insert one record at a time to keep memory low during insertion.
    for count, (text, chunk_id, dense, sparse) in enumerate(records, start=1):
        milvus_client.insert(
            collection_name=COLLECTION_NAME,
            data=[{"id": chunk_id, "text": text, "sparse": sparse.as_dict(), "dense": dense}],
        )
        # Flush every 100 records to reduce RAM usage
        if count % 100 == 0:
            milvus_client.flush(collection_name=COLLECTION_NAME)
        print(f"\rProcessed chunk {count}/{total}", end='', flush=True)
    # Final flush covers the records after the last multiple of 100.
    milvus_client.flush(collection_name=COLLECTION_NAME)
    print("\n")
|
||
def convert_qdrant(self, results):
    """Chunk the converted documents and add them to a Qdrant collection.

    Opens a Qdrant client on ``self.output``, selects the dense and sparse
    embedding models, creates the ``COLLECTION_NAME`` collection and adds every
    chunk returned by ``self.chunk``.

    Args:
        results: conversion results, forwarded to ``self.chunk``.

    Returns:
        Whatever the client's ``add`` call returns.
    """
    store = qdrant_client.QdrantClient(path=self.output)
    store.set_model(EMBED_MODEL)
    store.set_sparse_model(SPARSE_MODEL)

    # optimizations to reduce ram
    int8_quantization = models.ScalarQuantization(
        scalar=models.ScalarQuantizationConfig(
            type=models.ScalarType.INT8,
            always_ram=True,
        ),
    )
    store.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=store.get_fastembed_vector_params(on_disk=True),
        sparse_vectors_config=store.get_fastembed_sparse_vector_params(on_disk=True),
        quantization_config=int8_quantization,
    )

    texts, point_ids = self.chunk(results)
    return store.add(COLLECTION_NAME, documents=texts, ids=point_ids, batch_size=1)
|
||
def show_progress(self, message, stop_event):
    """Animate ``message`` with trailing dots on stdout until told to stop.

    Intended to run in a helper thread while a long operation is in progress.

    Args:
        message: text shown before the animated dots.
        stop_event: ``threading.Event`` (or compatible object providing
            ``is_set()`` and ``wait(timeout)``); setting it ends the animation.
    """
    frames = itertools.cycle([".", "..", "..."])
    while not stop_event.is_set():
        # Pad every frame to the widest one so that "." fully overwrites the
        # previous "..." instead of leaving a stale dot on the line.
        sys.stdout.write(f"\r{message} {next(frames):<3} ")
        sys.stdout.flush()
        # Wait on the event rather than sleeping so the thread wakes up as soon
        # as the caller sets it.
        stop_event.wait(0.5)
    # Blank the whole line: at least 50 columns, wider when the message
    # plus dots is longer than that.
    blank_width = max(50, len(message) + 5)
    sys.stdout.write("\r" + " " * blank_width + "\r")
|
||
def convert(self):
    """Convert every queued source and write the results in ``self.format``.

    Each entry of ``self.sources`` is converted on its own while a helper
    thread runs ``self.show_progress``; the collected results are then passed
    to the writer matching ``self.format`` ("qdrant", "markdown", "json" or
    "milvus").

    Returns:
        The selected writer's return value, or ``None`` when ``self.format``
        matches none of the known formats.
    """
    results = []
    for source in self.sources:
        # Show the real file name; sources are not necessarily PDFs.
        label = Path(str(source)).name
        stop_event = threading.Event()
        progress_thread = threading.Thread(
            target=self.show_progress,
            args=(f"Converting {label}", stop_event),
        )
        progress_thread.start()
        try:
            results.append(self.doc_converter.convert(source))
        finally:
            # Always stop the animation, even when the conversion raises.
            stop_event.set()
            progress_thread.join()
        print(f"Finished converting {label}")

    if self.format == "qdrant":
        return self.convert_qdrant(results)
    if self.format == "markdown":
        # Export the converted document to Markdown
        return self.convert_markdown(results)
    if self.format == "json":
        # Export the converted document to JSON
        return self.convert_json(results)
    if self.format == "milvus":
        return self.convert_milvus(results)
    return None
|
||
def convert_markdown(self, results): | ||
ctr = 0 | ||
# Process the conversion results | ||
for ctr, result in enumerate(results): | ||
dirname = self.output + os.path.dirname(self.sources[ctr]) | ||
os.makedirs(dirname, exist_ok=True) | ||
|
@@ -111,12 +186,12 @@ class Converter: | |
if os.path.isfile(file): | ||
self.sources.append(file) | ||
|
||
def generate_hash(self, document: str) -> int:
    """Derive a deterministic non-negative 63-bit integer id from ``document``.

    The text is hashed with SHA-256; the first 32 hex characters (128 bits) of
    the digest are read as an integer and masked to the low 63 bits so the
    value fits a signed 64-bit field (Milvus requires signed 64-bit).

    Args:
        document: text to hash; encoded as UTF-8.

    Returns:
        An int in ``[0, 2**63 - 1]``; the same text always yields the same id.
    """
    sha256_hash = hashlib.sha256(document.encode('utf-8')).hexdigest()
    # Same value as uuid.UUID(sha256_hash[:32]).int
    value = int(sha256_hash[:32], 16)
    # Keep the low 63 bits: non-negative and within signed int64 range
    return value & ((1 << 63) - 1)
|
||
|
||
def load(): | ||
|
@@ -139,7 +214,7 @@ parser.add_argument( | |
"--format", | ||
default="qdrant", | ||
help="Output format for RAG Data", | ||
choices=["qdrant", "json", "markdown"], | ||
choices=["qdrant", "json", "markdown", "milvus"], | ||
) | ||
parser.add_argument( | ||
"--ocr", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of embedding a test string to find the embedding dimension, you can get it directly from the model object. Assuming a recent version of `fastembed`, you can use the `.dim` property, which is more efficient and readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'TextEmbedding' object has no attribute 'dim'