Skip to content

Commit 28a7cbd

Browse files
authored
feat: support jina rerank for text & image (#62)
* feat: support jina rerank for text & image Signed-off-by: Keming <[email protected]> * align the interface Signed-off-by: Keming <[email protected]> --------- Signed-off-by: Keming <[email protected]>
1 parent 5944604 commit 28a7cbd

File tree

6 files changed

+166
-26
lines changed

6 files changed

+166
-26
lines changed

vechord/model/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
SparseEmbedding,
1414
UMBRELAScore,
1515
)
16-
from vechord.model.jina import JinaEmbeddingRequest, JinaEmbeddingResponse
16+
from vechord.model.jina import (
17+
JinaEmbeddingRequest,
18+
JinaEmbeddingResponse,
19+
JinaRerankRequest,
20+
JinaRerankResponse,
21+
)
1722
from vechord.model.llamacloud import (
1823
LlamaCloudMimeType,
1924
LlamaCloudParseRequest,
@@ -44,6 +49,8 @@
4449
"InputType",
4550
"JinaEmbeddingRequest",
4651
"JinaEmbeddingResponse",
52+
"JinaRerankRequest",
53+
"JinaRerankResponse",
4754
"LlamaCloudMimeType",
4855
"LlamaCloudParseRequest",
4956
"LlamaCloudParseResponse",

vechord/model/jina.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,61 @@ def get_emb(self) -> np.ndarray:
8484
if isinstance(emb, list):
8585
return np.array(emb, dtype=np.float32)
8686
return np.frombuffer(emb, dtype=np.float32)
87+
88+
89+
class JinaRerankRequest(msgspec.Struct, kw_only=True):
    """Request body sent to the Jina rerank endpoint.

    Instances are normally built with :meth:`from_query_docs` (plain text
    documents) or :meth:`from_query_multimodal` (text or image documents).
    """

    model: Literal["jina-reranker-v2-base-multilingual", "jina-reranker-m0"]
    query: str
    # Both builders below set this to the number of documents supplied.
    top_n: int
    documents: list[str | JinaInput]
    return_documents: bool = False

    @classmethod
    def from_query_docs(
        cls,
        query: str,
        documents: list[str],
        model: Literal["jina-reranker-m0", "jina-reranker-v2-base-multilingual"],
    ) -> Self:
        """Build a rerank request for text documents.

        Args:
            query: the search query to rank the documents against.
            documents: the candidate texts.
            model: the reranker model name. For ``"jina-reranker-m0"`` every
                text is wrapped in ``JinaInput(text=...)``; for any other
                model the strings are passed through unchanged.

        Raises:
            RequestError: if ``query`` or ``documents`` is empty.
        """
        if not query or not documents:
            raise RequestError("Query and documents must be provided")

        if model == "jina-reranker-m0":
            payload = [JinaInput(text=doc) for doc in documents]
        else:
            payload = documents

        # Use `cls` so that subclasses get an instance of their own type,
        # which is what the `Self` return annotation promises.
        return cls(
            model=model,
            query=query,
            top_n=len(documents),
            documents=payload,
        )

    @classmethod
    def from_query_multimodal(
        cls,
        query: str,
        documents: list[str],
        doc_type: Literal["text", "image"],
        model: Literal["jina-reranker-m0"] = "jina-reranker-m0",
    ) -> Self:
        """Build a rerank request whose documents are all of one modality.

        Args:
            query: the search query to rank the documents against.
            documents: the candidate documents as strings.
            doc_type: ``"text"`` wraps each document as ``JinaInput(text=...)``;
                any other value wraps it as ``JinaInput(image=...)``.
                NOTE(review): the accepted image encoding (URL, base64, ...)
                is decided by ``JinaInput`` / the API, not by this method.
            model: the reranker model name.

        Raises:
            RequestError: if ``query`` or ``documents`` is empty (same check
                and message as :meth:`from_query_docs`).
        """
        if not query or not documents:
            raise RequestError("Query and documents must be provided")

        docs = [
            JinaInput(text=doc) if doc_type == "text" else JinaInput(image=doc)
            for doc in documents
        ]
        return cls(
            model=model,
            query=query,
            top_n=len(docs),
            documents=docs,
        )
133+
134+
135+
class RerankObject(msgspec.Struct, kw_only=True):
    """A single scored entry inside a Jina rerank response."""

    # Collected by `JinaRerankResponse.get_indices`; presumably the position
    # of the document in the submitted `documents` list -- TODO confirm
    # against the Jina API reference.
    index: int
    # Score reported for this entry.
    relevance_score: float
138+
139+
140+
class JinaRerankResponse(msgspec.Struct, kw_only=True):
    """Decoded body of a Jina rerank reply."""

    results: list[RerankObject]

    def get_indices(self) -> list[int]:
        """Return the ``index`` of every result, in the order they appear."""
        ordered: list[int] = []
        for entry in self.results:
            ordered.append(entry.index)
        return ordered

vechord/pipeline.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
RunRequest,
3838
RunResponse,
3939
)
40-
from vechord.rerank import CohereReranker
40+
from vechord.rerank import CohereReranker, JinaReranker
4141
from vechord.spec import (
4242
AnyOf,
4343
DefaultDocument,
@@ -128,7 +128,7 @@ class _Relation(Table, kw_only=True):
128128
"jina": JinaMultiModalEmbedding,
129129
},
130130
"ocr": {"gemini": GeminiExtractor, "llamaparse": LlamaParseExtractor},
131-
"rerank": {"cohere": CohereReranker},
131+
"rerank": {"cohere": CohereReranker, "jina": JinaReranker},
132132
"graph": {"gemini": GeminiEntityRecognizer},
133133
"index": {"vectorchord": IndexOption},
134134
"search": {"vectorchord": SearchOption},
@@ -184,8 +184,6 @@ def __post_init__(self):
184184
raise RequestError("Vector index is required if `index` is specified")
185185
if self.search and not (self.text_emb or self.multimodal_emb):
186186
raise RequestError("Search requires at least one embedding provider")
187-
if self.search and self.rerank and self.multimodal_emb:
188-
raise RequestError("Rerank only supports text")
189187

190188
@classmethod
191189
def from_steps(cls, steps: list[ResourceRequest]) -> Self:
@@ -340,21 +338,17 @@ async def run_index(self, request: RunRequest, vr: "VechordRegistry") -> RunAck:
340338
rels.extend(conv_rels)
341339
chunks.append(chunk)
342340

343-
async with (
344-
vr.client.get_connection() as conn,
345-
limit_to_transaction_buffer_conn(conn),
346-
):
347-
await vr.insert(doc)
348-
for chunk in chunks:
349-
await vr.insert(chunk)
350-
if self.index.graph:
351-
if request.input_type is not InputType.TEXT:
352-
# insert the fake chunk for image/pdf
353-
await vr.insert(fake_chunk)
354-
await self.graph_insert(
355-
ents=ents, rels=rels, ent_cls=Entity, rel_cls=Relation, vr=vr
356-
)
357-
return RunAck(name=request.name, msg="succeed", uid=doc.uid)
341+
await vr.insert(doc)
342+
for chunk in chunks:
343+
await vr.insert(chunk)
344+
if self.index.graph:
345+
if request.input_type is not InputType.TEXT:
346+
# insert the fake chunk for image/pdf
347+
await vr.insert(fake_chunk)
348+
await self.graph_insert(
349+
ents=ents, rels=rels, ent_cls=Entity, rel_cls=Relation, vr=vr
350+
)
351+
return RunAck(name=request.name, msg="succeed", uid=doc.uid)
358352

359353
async def graph_insert(
360354
self,
@@ -440,7 +434,16 @@ class Relation(_Relation):
440434
if self.search.graph:
441435
resp.extend(await self.graph_search(query, Chunk, Entity, Relation, vr))
442436
if self.rerank:
443-
indices = await self.rerank.rerank(query, [chunk.text for chunk in resp])
437+
if self.multimodal_emb:
438+
indices = await self.rerank.rerank_multimodal(
439+
query=query,
440+
chunks=[chunk.text for chunk in resp],
441+
doc_type=resp.chunk_type,
442+
)
443+
else:
444+
indices = await self.rerank.rerank(
445+
query=query, chunks=[chunk.text for chunk in resp]
446+
)
444447
resp.reorder(indices)
445448
if self.evaluate:
446449
resp.metrics = await self.evaluate.evaluate_with_estimation(

vechord/provider.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
GeminiGenerateResponse,
1414
JinaEmbeddingRequest,
1515
JinaEmbeddingResponse,
16+
JinaRerankRequest,
17+
JinaRerankResponse,
1618
VoyageEmbeddingRequest,
1719
VoyageEmbeddingResponse,
1820
VoyageMultiModalEmbeddingRequest,
@@ -22,6 +24,7 @@
2224
GEMINI_EMBEDDING_RPS,
2325
GEMINI_GENERATE_RPS,
2426
JINA_EMBEDDING_RPS,
27+
JINA_RERANK_RPS,
2528
VOYAGE_EMBEDDING_RPS,
2629
RateLimitTransport,
2730
)
@@ -70,7 +73,6 @@ def __init__(self, model: str = "gemini-2.5-flash"):
7073
)
7174

7275
async def query(self, req: GeminiGenerateRequest) -> GeminiGenerateResponse:
73-
"""Query the Gemini model with a request."""
7476
response = await self.client.post(self.url, content=self.encoder.encode(req))
7577
if response.is_error:
7678
raise HTTPCallError(
@@ -107,7 +109,6 @@ def __init__(
107109
self.decoder = msgspec.json.Decoder(GeminiEmbeddingResponse)
108110

109111
async def query(self, req: GeminiEmbeddingRequest) -> GeminiEmbeddingResponse:
110-
"""Query the Gemini embedding model with a request."""
111112
response = await self.client.post(self.url, content=self.encoder.encode(req))
112113
if response.is_error:
113114
raise HTTPCallError(
@@ -137,7 +138,6 @@ def __init__(self, model: str = "jina-embeddings-v4", dim: int = 2048):
137138
self.decoder = msgspec.json.Decoder(JinaEmbeddingResponse)
138139

139140
async def query(self, req: JinaEmbeddingRequest) -> JinaEmbeddingResponse:
140-
"""Query the Jina embedding model with a request."""
141141
response = await self.client.post(self.url, content=self.encoder.encode(req))
142142
if response.is_error:
143143
raise HTTPCallError(
@@ -146,6 +146,34 @@ async def query(self, req: JinaEmbeddingRequest) -> JinaEmbeddingResponse:
146146
return self.decoder.decode(response.content)
147147

148148

149+
class JinaRerankProvider(BaseProvider):
    """Jina Rerank Provider."""

    PROVIDER_NAME = "JINA"

    def __init__(self, model: str = "jina-reranker-m0"):
        """Set up the HTTP client and JSON codec for the rerank endpoint.

        Args:
            model: the reranker model name, forwarded to ``BaseProvider``.
        """
        super().__init__(model)
        self.url = "https://api.jina.ai/v1/rerank"
        self.encoder = msgspec.json.Encoder()
        self.decoder = msgspec.json.Decoder(JinaRerankResponse)

        # `self.api_key` is read here right after the base initializer runs.
        auth_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        self.client = httpx.AsyncClient(
            headers=auth_headers,
            timeout=httpx.Timeout(120.0, connect=10.0),
            transport=RateLimitTransport(max_per_second=JINA_RERANK_RPS),
        )

    async def query(self, req: JinaRerankRequest) -> JinaRerankResponse:
        """POST the encoded request and decode the reply.

        Raises:
            HTTPCallError: if the HTTP response reports an error status.
        """
        payload = self.encoder.encode(req)
        response = await self.client.post(self.url, content=payload)
        if not response.is_error:
            return self.decoder.decode(response.content)
        raise HTTPCallError(
            "Failed to query Jina rerank", response.status_code, response.text
        )
175+
176+
149177
class VoyageEmbeddingProvider(BaseProvider):
150178
"""Voyage Embedding Provider."""
151179

@@ -169,7 +197,6 @@ def __init__(self, model: str = "voyage-3.5", dim: int = 1024):
169197
async def query(
170198
self, req: VoyageEmbeddingRequest | VoyageMultiModalEmbeddingRequest
171199
) -> VoyageEmbeddingResponse:
172-
"""Query the Voyage embedding model with a request."""
173200
response = await self.client.post(self.url, content=self.encoder.encode(req))
174201
if response.is_error:
175202
raise HTTPCallError(

vechord/rerank.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from os import environ
55
from typing import TypeVar
66

7+
from vechord.model import JinaRerankRequest
8+
from vechord.provider import JinaRerankProvider
79
from vechord.spec import Table
810

911
T = TypeVar("T", bound=Table)
@@ -15,9 +17,19 @@ async def rerank(self, query: str, chunks: list[str]) -> list[int]:
1517
"""Return the indices of the reranked chunks."""
1618
raise NotImplementedError
1719

20+
@abstractmethod
async def rerank_multimodal(
    self,
    query: str,
    chunks: list[str],
    doc_type: str,
) -> list[int]:
    """Return the indices of the reranked multimodal chunks.

    Args:
        query: the search query.
        chunks: the candidate chunks as strings.
        doc_type: the modality shared by all chunks.
    """
    raise NotImplementedError
26+
1827

1928
class CohereReranker(BaseReranker):
20-
"""Rerank chunks using Cohere API (requires env `COHERE_API_KEY`)."""
29+
"""Rerank chunks using Cohere API (requires env `COHERE_API_KEY`).
30+
31+
Only supports rerank documents.
32+
"""
2133

2234
def __init__(self, model: str = "rerank-v3.5"):
2335
self.api_key = environ.get("COHERE_API_KEY")
@@ -45,6 +57,38 @@ async def rerank(self, query: str, chunks: list[str]) -> list[int]:
4557
)
4658
return [item.index for item in resp.results]
4759

60+
async def rerank_multimodal(
    self,
    query: str,
    chunks: list[str],
    doc_type: str,
) -> list[int]:
    """Always raise: this reranker has no multimodal implementation.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "Cohere does not support multimodal reranking."
    raise NotImplementedError(message)
64+
65+
66+
class JinaReranker(BaseReranker, JinaRerankProvider):
    """Rerank chunks using Jina Rerank API (requires env `JINA_API_KEY`)."""

    def __init__(self, model: str = "jina-reranker-m0"):
        """Initialize the provider side and remember the model name.

        Args:
            model: the reranker model name used for text reranking.
        """
        super().__init__(model)
        # Kept on this class explicitly so `rerank` can pass it to the request
        # builder without depending on attributes set by the base classes.
        self._rerank_model = model

    async def rerank(self, query: str, chunks: list[str]) -> list[int]:
        """Return the indices of the reranked text chunks.

        Args:
            query: the search query.
            chunks: the candidate texts.
        """
        # `from_query_docs` takes `documents` (not `docs`) and requires
        # `model`; calling it with `docs=` alone raises TypeError.
        request = JinaRerankRequest.from_query_docs(
            query=query,
            documents=chunks,
            model=self._rerank_model,
        )
        resp = await self.query(request)
        return resp.get_indices()

    async def rerank_multimodal(
        self, query: str, chunks: list[str], doc_type: str
    ) -> list[int]:
        """Return the indices of the reranked multimodal chunks.

        Args:
            query: the search query.
            chunks: the candidate documents as strings.
            doc_type: "text" or "image"
        """
        # The request builder's own default model is used here, as before.
        request = JinaRerankRequest.from_query_multimodal(
            query=query, documents=chunks, doc_type=doc_type
        )
        resp = await self.query(request)
        return resp.get_indices()
91+
4892

4993
class ReciprocalRankFusion:
5094
"""Fuse chunks using reciprocal rank."""

vechord/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
VOYAGE_EMBEDDING_RPS = 33.33
1313
# https://jina.ai/api-dashboard/rate-limit
1414
JINA_EMBEDDING_RPS = 8.33
15+
JINA_RERANK_RPS = 8.33
1516

1617

1718
class RateLimitTransport(httpx.AsyncHTTPTransport):

0 commit comments

Comments
 (0)