Skip to content

Commit 2b31473

Browse files
authored
🎱 feat: Use Shared Thread Pool across Operations (#166)
1 parent 33a221a commit 2b31473

File tree

5 files changed

+126
-39
lines changed

5 files changed

+126
-39
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ uploads/
66
myenv/
77
venv/
88
*.pyc
9+
dev.yml
10+
SHOPIFY.md

app/routes/document_routes.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838

3939

4040
@router.get("/ids")
41-
async def get_all_ids():
41+
async def get_all_ids(request: Request):
4242
try:
4343
if isinstance(vector_store, AsyncPgVector):
44-
ids = await vector_store.get_all_ids()
44+
ids = await vector_store.get_all_ids(executor=request.app.state.thread_pool)
4545
else:
4646
ids = vector_store.get_all_ids()
4747

@@ -80,11 +80,11 @@ async def health_check():
8080

8181

8282
@router.get("/documents", response_model=list[DocumentResponse])
83-
async def get_documents_by_ids(ids: list[str] = Query(...)):
83+
async def get_documents_by_ids(request: Request, ids: list[str] = Query(...)):
8484
try:
8585
if isinstance(vector_store, AsyncPgVector):
86-
existing_ids = await vector_store.get_filtered_ids(ids)
87-
documents = await vector_store.get_documents_by_ids(ids)
86+
existing_ids = await vector_store.get_filtered_ids(ids, executor=request.app.state.thread_pool)
87+
documents = await vector_store.get_documents_by_ids(ids, executor=request.app.state.thread_pool)
8888
else:
8989
existing_ids = vector_store.get_filtered_ids(ids)
9090
documents = vector_store.get_documents_by_ids(ids)
@@ -118,11 +118,11 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):
118118

119119

120120
@router.delete("/documents")
121-
async def delete_documents(document_ids: List[str] = Body(...)):
121+
async def delete_documents(request: Request, document_ids: List[str] = Body(...)):
122122
try:
123123
if isinstance(vector_store, AsyncPgVector):
124-
existing_ids = await vector_store.get_filtered_ids(document_ids)
125-
await vector_store.delete(ids=document_ids)
124+
existing_ids = await vector_store.get_filtered_ids(document_ids, executor=request.app.state.thread_pool)
125+
await vector_store.delete(ids=document_ids, executor=request.app.state.thread_pool)
126126
else:
127127
existing_ids = vector_store.get_filtered_ids(document_ids)
128128
vector_store.delete(ids=document_ids)
@@ -175,12 +175,11 @@ async def query_embeddings_by_file_id(
175175
embedding = get_cached_query_embedding(body.query)
176176

177177
if isinstance(vector_store, AsyncPgVector):
178-
documents = await run_in_executor(
179-
None,
180-
vector_store.similarity_search_with_score_by_vector,
178+
documents = await vector_store.asimilarity_search_with_score_by_vector(
181179
embedding,
182180
k=body.k,
183181
filter={"file_id": body.file_id},
182+
executor=request.app.state.thread_pool
184183
)
185184
else:
186185
documents = vector_store.similarity_search_with_score_by_vector(
@@ -246,6 +245,7 @@ async def store_data_in_vector_db(
246245
file_id: str,
247246
user_id: str = "",
248247
clean_content: bool = False,
248+
executor = None,
249249
) -> bool:
250250
text_splitter = RecursiveCharacterTextSplitter(
251251
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
@@ -274,7 +274,7 @@ async def store_data_in_vector_db(
274274
try:
275275
if isinstance(vector_store, AsyncPgVector):
276276
ids = await vector_store.aadd_documents(
277-
docs, ids=[file_id] * len(documents)
277+
docs, ids=[file_id] * len(documents), executor=executor
278278
)
279279
else:
280280
ids = vector_store.add_documents(docs, ids=[file_id] * len(documents))
@@ -312,9 +312,9 @@ async def embed_local_file(
312312
loader, known_type, file_ext = get_loader(
313313
document.filename, document.file_content_type, document.filepath
314314
)
315-
data = await run_in_executor(None, loader.load)
315+
data = await run_in_executor(request.app.state.thread_pool, loader.load)
316316
result = await store_data_in_vector_db(
317-
data, document.file_id, user_id, clean_content=file_ext == "pdf"
317+
data, document.file_id, user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
318318
)
319319

320320
if result:
@@ -390,9 +390,9 @@ async def embed_file(
390390
loader, known_type, file_ext = get_loader(
391391
file.filename, file.content_type, temp_file_path
392392
)
393-
data = await run_in_executor(None, loader.load)
393+
data = await run_in_executor(request.app.state.thread_pool, loader.load)
394394
result = await store_data_in_vector_db(
395-
data=data, file_id=file_id, user_id=user_id, clean_content=file_ext == "pdf"
395+
data=data, file_id=file_id, user_id=user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
396396
)
397397

398398
if not result:
@@ -454,12 +454,12 @@ async def embed_file(
454454

455455

456456
@router.get("/documents/{id}/context")
457-
async def load_document_context(id: str):
457+
async def load_document_context(request: Request, id: str):
458458
ids = [id]
459459
try:
460460
if isinstance(vector_store, AsyncPgVector):
461-
existing_ids = await vector_store.get_filtered_ids(ids)
462-
documents = await vector_store.get_documents_by_ids(ids)
461+
existing_ids = await vector_store.get_filtered_ids(ids, executor=request.app.state.thread_pool)
462+
documents = await vector_store.get_documents_by_ids(ids, executor=request.app.state.thread_pool)
463463
else:
464464
existing_ids = vector_store.get_filtered_ids(ids)
465465
documents = vector_store.get_documents_by_ids(ids)
@@ -525,9 +525,9 @@ async def embed_file_upload(
525525
uploaded_file.filename, uploaded_file.content_type, temp_file_path
526526
)
527527

528-
data = await run_in_executor(None, loader.load)
528+
data = await run_in_executor(request.app.state.thread_pool, loader.load)
529529
result = await store_data_in_vector_db(
530-
data, file_id, user_id, clean_content=file_ext == "pdf"
530+
data, file_id, user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
531531
)
532532

533533
if not result:
@@ -566,19 +566,18 @@ async def embed_file_upload(
566566

567567

568568
@router.post("/query_multiple")
569-
async def query_embeddings_by_file_ids(body: QueryMultipleBody):
569+
async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody):
570570
try:
571571
# Get the embedding of the query text
572572
embedding = get_cached_query_embedding(body.query)
573573

574574
# Perform similarity search with the query embedding and filter by the file_ids in metadata
575575
if isinstance(vector_store, AsyncPgVector):
576-
documents = await run_in_executor(
577-
None,
578-
vector_store.similarity_search_with_score_by_vector,
576+
documents = await vector_store.asimilarity_search_with_score_by_vector(
579577
embedding,
580578
k=body.k,
581579
filter={"file_id": {"$in": body.file_ids}},
580+
executor=request.app.state.thread_pool
582581
)
583582
else:
584583
documents = vector_store.similarity_search_with_score_by_vector(
Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,75 @@
1-
from typing import Optional
1+
from typing import Optional, List, Tuple, Dict, Any
2+
import asyncio
23
from langchain_core.documents import Document
34
from langchain_core.runnables.config import run_in_executor
45
from .extended_pg_vector import ExtendedPgVector
56

67
class AsyncPgVector(ExtendedPgVector):
    """Async wrappers around the blocking ``ExtendedPgVector`` operations.

    Each coroutine hands the corresponding synchronous method to
    ``run_in_executor``. Callers may pass an explicit ``executor``; when they
    do not, the value returned by ``_get_thread_pool()`` is used instead
    (which may be ``None``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``ExtendedPgVector`` and reset the cached executor."""
        super().__init__(*args, **kwargs)
        # Lazily filled by _get_thread_pool(); None means "nothing cached yet".
        self._thread_pool = None

    def _get_thread_pool(self):
        """Return the fallback executor used when a caller passes none.

        This is only a fallback: in practice the executor is passed
        explicitly by the caller. On first use it tries to pick up the
        running event loop's default executor and caches whatever it finds.

        Returns:
            The cached executor, or ``None`` when no loop is running or the
            loop has no default executor attribute set.
        """
        if self._thread_pool is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop in this thread: nothing to fall back to.
                return None
            # NOTE(review): `_default_executor` is a private asyncio attribute;
            # getattr with a default keeps this safe if it is absent.
            self._thread_pool = getattr(loop, "_default_executor", None)
        return self._thread_pool

    async def get_all_ids(self, executor=None) -> list[str]:
        """Run the parent ``get_all_ids`` via ``run_in_executor``.

        Args:
            executor: Optional executor; falls back to ``_get_thread_pool()``.
        """
        executor = executor or self._get_thread_pool()
        return await run_in_executor(executor, super().get_all_ids)

    async def get_filtered_ids(self, ids: list[str], executor=None) -> list[str]:
        """Run the parent ``get_filtered_ids(ids)`` via ``run_in_executor``.

        Args:
            ids: Identifiers forwarded unchanged to the parent method.
            executor: Optional executor; falls back to ``_get_thread_pool()``.
        """
        executor = executor or self._get_thread_pool()
        return await run_in_executor(executor, super().get_filtered_ids, ids)

    async def get_documents_by_ids(self, ids: list[str], executor=None) -> list[Document]:
        """Run the parent ``get_documents_by_ids(ids)`` via ``run_in_executor``.

        Args:
            ids: Identifiers forwarded unchanged to the parent method.
            executor: Optional executor; falls back to ``_get_thread_pool()``.
        """
        executor = executor or self._get_thread_pool()
        return await run_in_executor(executor, super().get_documents_by_ids, ids)

    async def delete(
        self, ids: Optional[list[str]] = None, collection_only: bool = False, executor=None
    ) -> None:
        """Run ``self._delete_multiple(ids, collection_only)`` via ``run_in_executor``.

        Args:
            ids: Identifiers forwarded to ``_delete_multiple``.
            collection_only: Flag forwarded to ``_delete_multiple``.
            executor: Optional executor; falls back to ``_get_thread_pool()``.
        """
        executor = executor or self._get_thread_pool()
        await run_in_executor(executor, self._delete_multiple, ids, collection_only)

    async def asimilarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        executor=None
    ) -> List[Tuple[Document, float]]:
        """Async version of similarity_search_with_score_by_vector.

        Args:
            embedding: Query vector passed positionally to the parent method.
            k: Passed positionally as the parent's second argument.
            filter: Passed positionally as the parent's third argument.
            executor: Optional executor; falls back to ``_get_thread_pool()``.
        """
        executor = executor or self._get_thread_pool()
        return await run_in_executor(
            executor,
            super().similarity_search_with_score_by_vector,
            embedding,
            k,
            filter
        )

    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        executor=None,
        **kwargs
    ) -> List[str]:
        """Async version of add_documents.

        Args:
            documents: Documents forwarded to the parent ``add_documents``.
            ids: Forwarded as the ``ids`` keyword argument.
            executor: Optional executor; falls back to ``_get_thread_pool()``.
            **kwargs: Extra keyword arguments forwarded unchanged.
        """
        executor = executor or self._get_thread_pool()
        return await run_in_executor(
            executor,
            super().add_documents,
            documents,
            ids=ids,
            **kwargs
        )

main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# main.py
2+
import os
23
import uvicorn
34
from fastapi import FastAPI, Request
45
from fastapi.exceptions import RequestValidationError
56
from fastapi.middleware.cors import CORSMiddleware
67
from contextlib import asynccontextmanager
8+
from concurrent.futures import ThreadPoolExecutor
79

810
from starlette.responses import JSONResponse
911

@@ -16,11 +18,21 @@
1618
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create and tear down application-wide resources.

    Startup creates a bounded ``ThreadPoolExecutor`` stored on
    ``app.state.thread_pool`` and, when the configured vector DB type is
    pgvector, initializes the database pool and ensures the custom-id index.
    Shutdown always shuts the thread pool down, even if startup or the
    application raised.

    The pool size comes from the ``RAG_THREAD_POOL_SIZE`` environment
    variable (default: CPU count), clamped to the range 1..8.
    """
    # Startup logic goes here
    # os.cpu_count() may return None when the count cannot be determined.
    default_workers = os.cpu_count() or 1
    raw_size = os.getenv("RAG_THREAD_POOL_SIZE")
    try:
        requested_workers = int(raw_size) if raw_size is not None else default_workers
    except ValueError:
        logger.warning(
            f"Invalid RAG_THREAD_POOL_SIZE value {raw_size!r}; using {default_workers}"
        )
        requested_workers = default_workers
    # Cap at 8; ThreadPoolExecutor rejects max_workers <= 0, so floor at 1.
    max_workers = max(1, min(requested_workers, 8))
    app.state.thread_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="rag-worker")
    logger.info(f"Initialized thread pool with {max_workers} workers (CPU cores: {os.cpu_count()})")

    try:
        if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
            await PSQLDatabase.get_pool()  # Initialize the pool
            await ensure_custom_id_index_on_embedding()

        yield
    finally:
        # Cleanup logic: runs on normal shutdown and on startup/runtime errors,
        # so the worker threads are never leaked.
        logger.info("Shutting down thread pool")
        app.state.thread_pool.shutdown(wait=True)
        logger.info("Thread pool shutdown complete")
2436

2537
app = FastAPI(lifespan=lifespan, debug=debug_mode)
2638

tests/test_main.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from fastapi.testclient import TestClient
66
from langchain_core.documents import Document
7+
from concurrent.futures import ThreadPoolExecutor
78

89
from main import app
910

@@ -24,19 +25,23 @@ def auth_headers():
2425
def override_vector_store(monkeypatch):
2526
from app.config import vector_store
2627

28+
# Initialize thread pool for tests: TestClient only runs lifespan when used as a context manager
29+
if not hasattr(app.state, 'thread_pool') or app.state.thread_pool is None:
30+
app.state.thread_pool = ThreadPoolExecutor(max_workers=2, thread_name_prefix="test-worker")
31+
2732
# Override get_all_ids as an async function.
28-
async def dummy_get_all_ids():
33+
async def dummy_get_all_ids(executor=None):
2934
return ["testid1", "testid2"]
3035
monkeypatch.setattr(vector_store, "get_all_ids", dummy_get_all_ids)
3136

3237
# Override get_filtered_ids as an async function.
33-
async def dummy_get_filtered_ids(ids):
38+
async def dummy_get_filtered_ids(ids, executor=None):
3439
dummy_ids = ["testid1", "testid2"]
3540
return [id for id in dummy_ids if id in ids]
3641
monkeypatch.setattr(vector_store, "get_filtered_ids", dummy_get_filtered_ids)
3742

3843
# Override get_documents_by_ids as an async function.
39-
async def dummy_get_documents_by_ids(ids):
44+
async def dummy_get_documents_by_ids(ids, executor=None):
4045
return [
4146
Document(page_content="Test content", metadata={"file_id": id})
4247
for id in ids
@@ -56,22 +61,35 @@ def dummy_similarity_search_with_score_by_vector(embedding, k, filter):
5661
metadata={"file_id": filter.get("file_id", "testid1"), "user_id": "testuser"},
5762
)
5863
return [(doc, 0.9)]
64+
65+
async def dummy_asimilarity_search_with_score_by_vector(embedding, k, filter=None, executor=None):
66+
doc = Document(
67+
page_content="Queried content",
68+
metadata={"file_id": filter.get("file_id", "testid1") if filter else "testid1", "user_id": "testuser"},
69+
)
70+
return [(doc, 0.9)]
71+
5972
monkeypatch.setattr(
6073
vector_store,
6174
"similarity_search_with_score_by_vector",
6275
dummy_similarity_search_with_score_by_vector,
6376
)
77+
monkeypatch.setattr(
78+
vector_store,
79+
"asimilarity_search_with_score_by_vector",
80+
dummy_asimilarity_search_with_score_by_vector,
81+
)
6482

6583
# Override document addition functions.
6684
def dummy_add_documents(docs, ids):
6785
return ids
68-
async def dummy_aadd_documents(docs, ids):
86+
async def dummy_aadd_documents(docs, ids=None, executor=None):
6987
return ids
7088
monkeypatch.setattr(vector_store, "add_documents", dummy_add_documents)
7189
monkeypatch.setattr(vector_store, "aadd_documents", dummy_aadd_documents)
7290

7391
# Override delete function.
74-
async def dummy_delete(ids=None, collection_only=False):
92+
async def dummy_delete(ids=None, collection_only=False, executor=None):
7593
return None
7694
monkeypatch.setattr(vector_store, "delete", dummy_delete)
7795

0 commit comments

Comments
 (0)