Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion container-images/scripts/build_rag.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ docling() {
}

rag() {
    # Install the Python packages used by the RAG tooling under the /usr prefix.
    # A single invocation is enough: this list already contains every package
    # from the earlier, shorter list (it only adds pymilvus), so running both
    # would just repeat the same install.
    ${python} -m pip install --prefix=/usr wheel qdrant_client pymilvus fastembed openai fastapi uvicorn
    rag_framework load
}

Expand Down
141 changes: 108 additions & 33 deletions container-images/scripts/doc2rag
Original file line number Diff line number Diff line change
@@ -1,25 +1,36 @@
#!/usr/bin/env python3

# suppress pkg warning for pymilvus
import argparse
import errno
import hashlib
import itertools
import os
import os.path
import sys
import threading
import time
import uuid
import warnings
from pathlib import Path

import docling
import qdrant_client
from docling.chunking import HybridChunker
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from fastembed import SparseTextEmbedding, TextEmbedding
from pymilvus import DataType, MilvusClient
from qdrant_client import models

# Ignore every UserWarning-category warning raised after this call.
# NOTE(review): this runs after the imports above have executed, so warnings
# emitted while importing pymilvus are not covered — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)

# Global Vars
# Dense embedding model name; override with the EMBED_MODEL environment variable.
EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
# Sparse embedding model name; override with the SPARSE_MODEL environment variable.
SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
# Collection name shared by the qdrant and milvus output formats.
COLLECTION_NAME = "rag"
# NOTE(review): presumably read by the Hugging Face tokenizers library to
# enable parallel tokenization — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


class Converter:
Expand All @@ -37,58 +48,122 @@ class Converter:
self.doc_converter = DocumentConverter(
format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
if self.format == "qdrant":
self.client = qdrant_client.QdrantClient(path=self.output)
self.client.set_model(EMBED_MODEL)
self.client.set_sparse_model(SPARSE_MODEL)
# optimizations to reduce ram
self.client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=self.client.get_fastembed_vector_params(on_disk=True),
sparse_vectors_config=self.client.get_fastembed_sparse_vector_params(on_disk=True),
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
type=models.ScalarType.INT8,
always_ram=True,
),
),
)

def add(self, file_path):
    """Queue ``file_path`` for conversion.

    A plain path is appended to ``self.sources`` as-is; a directory is handed
    to ``self.walk`` instead, which is responsible for queuing its contents.
    """
    if not os.path.isdir(file_path):
        # Single file: queue it directly.
        self.sources.append(file_path)
        return
    # Directory: let walk() discover the files inside it.
    self.walk(file_path)

def chunk(self, docs):
    """Split converted documents into contextualized text chunks with ids.

    Args:
        docs: iterable of conversion results; each item must expose a
            ``.document`` attribute accepted by ``HybridChunker.chunk``.

    Returns:
        A ``(documents, ids)`` tuple of two parallel lists: the enriched text
        of every chunk, and the id produced by ``self.generate_hash`` for
        that text.
    """
    chunker = HybridChunker(tokenizer=EMBED_MODEL, overlap=100, merge_peers=True)
    documents, ids = [], []

    for file in docs:
        # ``piece`` rather than ``chunk`` so the loop variable does not shadow
        # this method's own name.
        for piece in chunker.chunk(dl_doc=file.document):
            # Extract the enriched text from the chunk
            doc_text = chunker.contextualize(chunk=piece)
            documents.append(doc_text)
            # The id is derived from the text, so identical chunks share an id.
            ids.append(self.generate_hash(doc_text))
    return documents, ids

def convert_milvus(self, docs):
"""Build a Milvus collection from ``docs`` in ``<self.output>/milvus.db``.

Creates the output directory, defines a schema with an INT64 ``id`` primary
key, a ``text`` VARCHAR field (max_length 1000) and ``sparse``/``dense``
vector fields, then chunks ``docs`` via ``self.chunk``, embeds all chunks with
the dense and sparse fastembed models, and inserts them one record at a time
with periodic flushes.

NOTE(review): chunk text longer than the 1000-character ``text`` limit is not
shortened before insertion — confirm how Milvus handles that.
"""
output_dir = Path(self.output)
output_dir.mkdir(parents=True, exist_ok=True)
milvus_client = MilvusClient(uri=os.path.join(self.output, "milvus.db"))
collection_name = COLLECTION_NAME
dmodel = TextEmbedding(model_name=EMBED_MODEL)
smodel = SparseTextEmbedding(model_name=SPARSE_MODEL)
# Embed a throwaway string so the dense vector length can be measured below.
test_embedding = next(dmodel.embed("This is a test"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of embedding a test string to find the embedding dimension, you can get it directly from the model object. Assuming a recent version of fastembed, you can use the .dim property, which is more efficient and readable.

embedding_dim = dmodel.dim

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'TextEmbedding' object has no attribute 'dim'

embedding_dim = len(test_embedding)
schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=True,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=1000)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The max_length for the text field is hardcoded to 1000 characters. If the document chunker produces chunks larger than this, the text will be truncated upon insertion into Milvus, leading to silent data loss and potentially affecting the quality of RAG results.

Milvus supports a max_length of up to 65,535 for VARCHAR. I recommend increasing this limit significantly to prevent data truncation.

schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional, to keep RAM usage low, but I will keep this in mind.

schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=embedding_dim)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="dense", index_name="dense_index", index_type="AUTOINDEX")
index_params.add_index(field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX")
milvus_client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)
# Chunk and add chunks to collection 1 by 1
chunks, ids = self.chunk(docs)
# Batch-embed chunks for better performance
dense_embeddings = list(dmodel.embed(chunks))
sparse_embeddings_list = list(smodel.embed(chunks))

for i, (chunk, id) in enumerate(zip(chunks, ids)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling dmodel.embed() and smodel.embed() inside a loop for each chunk is highly inefficient, as fastembed is optimized for batch processing. This will be very slow for a large number of chunks.

A much more performant approach is to embed all chunks in a single batch before the loop. While this will use more memory to hold all embeddings at once, the speed improvement will be substantial. You can still insert them one by one or in smaller batches to keep memory usage low during the insertion phase.

# Batch-embed chunks for better performance
dense_embeddings = list(dmodel.embed(chunks))
sparse_embeddings_list = list(smodel.embed(chunks))

for i, (chunk, id) in enumerate(zip(chunks, ids)):
    sparse_vector = sparse_embeddings_list[i].as_dict()
    milvus_client.insert(
        collection_name=collection_name,
        data=[{
            "id": id,
            "text": chunk,
            "sparse": sparse_vector,
            "dense": dense_embeddings[i]
        }]
    )
    # Flush every 100 records to reduce RAM usage
    if (i + 1) % 100 == 0: 
        milvus_client.flush(collection_name=collection_name)
    print(f"\rProcessed chunk {i+1}/{len(chunks)}", end='', flush=True)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool! Implemented and works!

sparse_vector = sparse_embeddings_list[i].as_dict()
milvus_client.insert(
collection_name=collection_name,
data=[{"id": id, "text": chunk, "sparse": sparse_vector, "dense": dense_embeddings[i]}],
)
# Flush every 100 records to reduce RAM usage
if i % 100 == 0:
milvus_client.flush(collection_name=collection_name)
print(f"\rProcessed chunk {i+1}/{len(chunks)}", end='', flush=True)
milvus_client.flush(collection_name=collection_name)
print("\n")
return

def convert_qdrant(self, results):
    """Create the Qdrant collection at ``self.output`` and load all chunks.

    Args:
        results: conversion results, passed unchanged to ``self.chunk``.

    Returns:
        Whatever the Qdrant client's ``add`` call returns.
    """
    client = qdrant_client.QdrantClient(path=self.output)
    client.set_model(EMBED_MODEL)
    client.set_sparse_model(SPARSE_MODEL)

    # optimizations to reduce ram: on-disk vectors plus INT8 scalar quantization
    dense_params = client.get_fastembed_vector_params(on_disk=True)
    sparse_params = client.get_fastembed_sparse_vector_params(on_disk=True)
    int8_config = models.ScalarQuantizationConfig(
        type=models.ScalarType.INT8,
        always_ram=True,
    )
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=dense_params,
        sparse_vectors_config=sparse_params,
        quantization_config=models.ScalarQuantization(scalar=int8_config),
    )

    texts, point_ids = self.chunk(results)
    return client.add(COLLECTION_NAME, documents=texts, ids=point_ids, batch_size=1)

def show_progress(self, message, stop_event):
    """Animate ``message`` with a dot spinner on stdout until ``stop_event`` is set.

    Intended to run in a helper thread while slow work happens elsewhere.

    Args:
        message: text shown before the spinner dots.
        stop_event: ``threading.Event`` (or compatible object) that the caller
            sets when the work is done.
    """
    frames = itertools.cycle([".", "..", "..."])
    while not stop_event.is_set():
        sys.stdout.write(f"\r{message} {next(frames)} ")
        sys.stdout.flush()
        # wait() returns as soon as the event is set, so the caller's join()
        # is not held up for the remainder of a fixed half-second sleep.
        stop_event.wait(0.5)
    # Blank out the spinner line. The longest frame is the message plus five
    # characters (space, three dots, space); never clear fewer than 50 columns.
    clear_width = max(50, len(message) + 5)
    sys.stdout.write("\r" + " " * clear_width + "\r")

def convert(self):
    """Convert every queued source, then export the results in ``self.format``.

    Each entry of ``self.sources`` is converted on its own with
    ``self.doc_converter.convert`` while ``self.show_progress`` animates a
    spinner in a helper thread. The collected results are then handed to the
    exporter matching ``self.format``.

    Returns:
        The selected exporter's return value, or ``None`` when ``self.format``
        matches none of the known formats.
    """
    results = []
    for source in self.sources:
        # Show the real file name; sources are not necessarily PDFs.
        label = Path(str(source)).name
        stop_event = threading.Event()
        progress_thread = threading.Thread(
            target=self.show_progress,
            args=(f"Converting {label}", stop_event),
        )
        progress_thread.start()
        try:
            results.append(self.doc_converter.convert(source))
        finally:
            # Always stop the spinner, even when the conversion raises.
            stop_event.set()
            progress_thread.join()
        print(f"Finished converting {label}")

    if self.format == "qdrant":
        return self.convert_qdrant(results)
    if self.format == "markdown":
        # Export the converted document to Markdown
        return self.convert_markdown(results)
    if self.format == "json":
        # Export the converted document to JSON
        return self.convert_json(results)
    if self.format == "milvus":
        return self.convert_milvus(results)
    return None

def convert_markdown(self, results):
ctr = 0
# Process the conversion results
for ctr, result in enumerate(results):
dirname = self.output + os.path.dirname(self.sources[ctr])
os.makedirs(dirname, exist_ok=True)
Expand All @@ -111,12 +186,12 @@ class Converter:
if os.path.isfile(file):
self.sources.append(file)

def generate_hash(self, document: str) -> int:
    """Return a deterministic, non-negative 63-bit integer id for ``document``.

    The id is derived from the SHA-256 digest of the UTF-8 encoded text, so
    the same text always maps to the same id. Distinct texts can in principle
    collide because only 63 bits of the digest are kept.

    Args:
        document: text to derive the id from.

    Returns:
        An int in the range ``0 <= id < 2**63``.
    """
    sha256_hash = hashlib.sha256(document.encode("utf-8")).hexdigest()

    # Use the first 32 hex characters (128 bits) of the hash to create a UUID
    uuid_val = uuid.UUID(sha256_hash[:32])
    # Keep the low 63 bits so the value fits a signed int64 without going
    # negative (Milvus requires signed 64-bit)
    return uuid_val.int & ((1 << 63) - 1)


def load():
Expand All @@ -139,7 +214,7 @@ parser.add_argument(
"--format",
default="qdrant",
help="Output format for RAG Data",
choices=["qdrant", "json", "markdown"],
choices=["qdrant", "json", "markdown", "milvus"],
)
parser.add_argument(
"--ocr",
Expand Down
Loading
Loading