1
1
#!/usr/bin/env python3
2
2
3
+ # suppress pkg warning for pymilvus
3
4
import argparse
4
5
import errno
5
6
import hashlib
7
+ import itertools
6
8
import os
7
9
import os .path
8
10
import sys
11
+ import threading
12
+ import time
9
13
import uuid
14
+ import warnings
15
+ from pathlib import Path
10
16
11
17
import docling
12
18
import qdrant_client
13
19
from docling .chunking import HybridChunker
14
20
from docling .datamodel .base_models import InputFormat
15
21
from docling .datamodel .pipeline_options import PdfPipelineOptions
16
22
from docling .document_converter import DocumentConverter , PdfFormatOption
23
+ from fastembed import SparseTextEmbedding , TextEmbedding
24
+ from pymilvus import DataType , MilvusClient
17
25
from qdrant_client import models
18
26
27
# Silence noisy UserWarnings (emitted by pymilvus' packaging shims on import).
warnings.filterwarnings("ignore", category=UserWarning)

# Global Vars
# Dense embedding model id; overridable via the EMBED_MODEL env var.
EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
# Sparse (SPLADE) embedding model id for hybrid search; overridable via env.
SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
# Single collection name shared by the qdrant and milvus backends.
COLLECTION_NAME = "rag"
# NOTE(review): setting this controls the HuggingFace tokenizers fork-parallelism
# warning/behavior — presumably set to keep tokenization parallel; confirm intent.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
23
34
24
35
25
36
class Converter :
@@ -37,58 +48,122 @@ class Converter:
37
48
self .doc_converter = DocumentConverter (
38
49
format_options = {InputFormat .PDF : PdfFormatOption (pipeline_options = pipeline_options )}
39
50
)
40
- if self .format == "qdrant" :
41
- self .client = qdrant_client .QdrantClient (path = self .output )
42
- self .client .set_model (EMBED_MODEL )
43
- self .client .set_sparse_model (SPARSE_MODEL )
44
- # optimizations to reduce ram
45
- self .client .create_collection (
46
- collection_name = COLLECTION_NAME ,
47
- vectors_config = self .client .get_fastembed_vector_params (on_disk = True ),
48
- sparse_vectors_config = self .client .get_fastembed_sparse_vector_params (on_disk = True ),
49
- quantization_config = models .ScalarQuantization (
50
- scalar = models .ScalarQuantizationConfig (
51
- type = models .ScalarType .INT8 ,
52
- always_ram = True ,
53
- ),
54
- ),
55
- )
56
51
57
52
def add(self, file_path):
    """Queue *file_path* for conversion; a directory is expanded recursively."""
    if not os.path.isdir(file_path):
        # Single file: queue it directly.
        self.sources.append(file_path)
    else:
        # Directory: walk it and queue every file found inside.
        self.walk(file_path)
62
57
63
- def convert_qdrant (self , results ):
64
- documents , ids = [], []
58
def chunk(self, docs):
    """Split converted documents into contextualized text chunks.

    Returns a ``(documents, ids)`` pair: the enriched chunk texts and the
    int64 hash id computed for each one.
    """
    chunker = HybridChunker(tokenizer=EMBED_MODEL, overlap=100, merge_peers=True)
    texts = []
    hashes = []
    for converted in docs:
        for piece in chunker.chunk(dl_doc=converted.document):
            # Enriched (contextualized) text for this chunk.
            text = chunker.contextualize(chunk=piece)
            texts.append(text)
            hashes.append(self.generate_hash(text))
    return texts, hashes
71
+
72
def convert_milvus(self, docs):
    """Chunk *docs*, embed them (dense + sparse), and store them in Milvus Lite.

    Creates ``<output>/milvus.db`` with one hybrid-search collection holding,
    per chunk: an int64 primary key, the chunk text, a sparse vector and a
    dense vector. Returns None.
    """
    output_dir = Path(self.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Use Path consistently (the original mixed os.path.join with pathlib).
    milvus_client = MilvusClient(uri=str(output_dir / "milvus.db"))
    collection_name = COLLECTION_NAME
    dmodel = TextEmbedding(model_name=EMBED_MODEL)
    smodel = SparseTextEmbedding(model_name=SPARSE_MODEL)
    # Probe the dense model once to discover the embedding dimensionality.
    test_embedding = next(dmodel.embed("This is a test"))
    embedding_dim = len(test_embedding)
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    # 65535 is the VARCHAR ceiling; the previous max_length=1000 made inserts
    # fail for chunks longer than 1000 characters.
    schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
    schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
    schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=embedding_dim)
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(field_name="dense", index_name="dense_index", index_type="AUTOINDEX")
    index_params.add_index(field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX")
    milvus_client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)
    chunks, ids = self.chunk(docs)
    # Batch-embed all chunks up front for better performance.
    dense_embeddings = list(dmodel.embed(chunks))
    sparse_embeddings = list(smodel.embed(chunks))

    # `chunk_id`, not `id`, so the builtin is not shadowed.
    for i, (text, chunk_id) in enumerate(zip(chunks, ids)):
        milvus_client.insert(
            collection_name=collection_name,
            data=[{
                "id": chunk_id,
                "text": text,
                "sparse": sparse_embeddings[i].as_dict(),
                "dense": dense_embeddings[i],
            }],
        )
        # Flush every 100 records to reduce RAM usage (the original also
        # flushed uselessly on the very first record, i == 0).
        if (i + 1) % 100 == 0:
            milvus_client.flush(collection_name=collection_name)
        print(f"\rProcessed chunk {i + 1}/{len(chunks)}", end='', flush=True)
    milvus_client.flush(collection_name=collection_name)
    print("\n")
    return
112
+
113
def convert_qdrant(self, results):
    """Chunk *results* and index them into an on-disk Qdrant collection."""
    client = qdrant_client.QdrantClient(path=self.output)
    client.set_model(EMBED_MODEL)
    client.set_sparse_model(SPARSE_MODEL)
    # Keep vectors on disk and quantize to int8 to reduce RAM usage.
    quantization = models.ScalarQuantization(
        scalar=models.ScalarQuantizationConfig(
            type=models.ScalarType.INT8,
            always_ram=True,
        ),
    )
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=client.get_fastembed_vector_params(on_disk=True),
        sparse_vectors_config=client.get_fastembed_sparse_vector_params(on_disk=True),
        quantization_config=quantization,
    )
    documents, ids = self.chunk(results)
    return client.add(COLLECTION_NAME, documents=documents, ids=ids, batch_size=1)
131
+
132
def show_progress(self, message, stop_event):
    """Animate a '.', '..', '...' spinner on stdout until *stop_event* fires."""
    dots = itertools.cycle((".", "..", "..."))
    while not stop_event.is_set():
        sys.stdout.write(f"\r{message}{next(dots)}")
        sys.stdout.flush()
        time.sleep(0.5)
    # Blank the progress line before handing the terminal back.
    sys.stdout.write("\r" + " " * 50 + "\r")
77
139
78
140
def convert(self):
    """Convert every queued source, then dispatch to the chosen output format.

    A spinner thread shows per-document progress, since docling conversion can
    take a while. Returns whatever the format-specific converter returns.
    """
    results = []
    for source in self.sources:
        # Redundant extra parens removed; `names` list (never read) dropped.
        name = Path(str(source)).stem
        stop_event = threading.Event()
        progress_thread = threading.Thread(
            target=self.show_progress,
            args=(f"Converting {name}.pdf", stop_event),
        )
        progress_thread.start()
        try:
            results.append(self.doc_converter.convert(source))
        finally:
            # Always stop and reap the spinner, even if conversion raised.
            stop_event.set()
            progress_thread.join()
        print(f"Finished converting {name}.pdf")

    if self.format == "qdrant":
        return self.convert_qdrant(results)
    if self.format == "markdown":
        return self.convert_markdown(results)
    if self.format == "json":
        return self.convert_json(results)
    if self.format == "milvus":
        # `return` added for consistency with the other branches
        # (convert_milvus returns None, so callers see the same value).
        return self.convert_milvus(results)
88
164
89
165
def convert_markdown (self , results ):
90
166
ctr = 0
91
- # Process the conversion results
92
167
for ctr , result in enumerate (results ):
93
168
dirname = self .output + os .path .dirname (self .sources [ctr ])
94
169
os .makedirs (dirname , exist_ok = True )
@@ -111,12 +186,12 @@ class Converter:
111
186
if os .path .isfile (file ):
112
187
self .sources .append (file )
113
188
114
def generate_hash(self, document: str) -> int:
    """Generate a unique int64 hash from the document text."""
    digest = hashlib.sha256(document.encode('utf-8')).hexdigest()
    # Fold the first 128 bits into a UUID, then mask down to a non-negative
    # signed int64 (Milvus primary keys must be signed 64-bit).
    return uuid.UUID(digest[:32]).int & ((1 << 63) - 1)
120
195
121
196
122
197
def load ():
@@ -139,7 +214,7 @@ parser.add_argument(
139
214
"--format" ,
140
215
default = "qdrant" ,
141
216
help = "Output format for RAG Data" ,
142
- choices = ["qdrant" , "json" , "markdown" ],
217
+ choices = ["qdrant" , "json" , "markdown" , "milvus" ],
143
218
)
144
219
parser .add_argument (
145
220
"--ocr" ,
0 commit comments