Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions libs/voyageai/langchain_voyageai/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,18 @@ def compress_documents(

Returns:
A sequence of compressed documents in relevance_score order.
Each document's metadata includes 'relevance_score' and 'total_tokens'.
"""
if len(documents) == 0:
return []

rerank_result = self._rerank(documents, query)
compressed = []
for res in self._rerank(documents, query).results:
for res in rerank_result.results:
doc = documents[res.index]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res.relevance_score
doc_copy.metadata["total_tokens"] = rerank_result.total_tokens
compressed.append(doc_copy)
return compressed

Expand All @@ -142,14 +145,17 @@ async def acompress_documents(

Returns:
A sequence of compressed documents in relevance_score order.
Each document's metadata includes 'relevance_score' and 'total_tokens'.
"""
if len(documents) == 0:
return []

rerank_result = await self._arerank(documents, query)
compressed = []
for res in (await self._arerank(documents, query)).results:
for res in rerank_result.results:
doc = documents[res.index]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res.relevance_score
doc_copy.metadata["total_tokens"] = rerank_result.total_tokens
compressed.append(doc_copy)
return compressed
8 changes: 8 additions & 0 deletions libs/voyageai/tests/integration_tests/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def test_sync() -> None:
query="When is the Apple's conference call scheduled?", documents=documents
)
assert len(doc_list) == len(result)
for doc in result:
assert "total_tokens" in doc.metadata
assert isinstance(doc.metadata["total_tokens"], int)
assert doc.metadata["total_tokens"] > 0


async def test_async() -> None:
Expand Down Expand Up @@ -66,3 +70,7 @@ async def test_async() -> None:
query="When is the Apple's conference call scheduled?", documents=documents
)
assert len(doc_list) == len(result)
for doc in result:
assert "total_tokens" in doc.metadata
assert isinstance(doc.metadata["total_tokens"], int)
assert doc.metadata["total_tokens"] > 0
4 changes: 2 additions & 2 deletions libs/voyageai/tests/unit_tests/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def test_rerank_unit_test(mocker: Any) -> None:
Document(
page_content="Photosynthesis in plants converts light energy into "
"glucose and produces essential oxygen.",
metadata={"relevance_score": 0.9},
metadata={"relevance_score": 0.9, "total_tokens": 255},
),
Document(
page_content="The Mediterranean diet emphasizes fish, olive oil, and "
"vegetables, believed to reduce chronic diseases.",
metadata={"relevance_score": 0.8},
metadata={"relevance_score": 0.8, "total_tokens": 255},
),
]

Expand Down