Skip to content

Commit 078f132

Browse files
committed
feat: enforce Voyage token limits in embeddings
1 parent bf0d733 commit 078f132

File tree

7 files changed

+492
-45
lines changed

7 files changed

+492
-45
lines changed

libs/voyageai/langchain_voyageai/embeddings.py

Lines changed: 128 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Iterable, List, Literal, Optional, cast
2+
from typing import Any, Iterable, Iterator, List, Literal, Optional, Tuple, cast
33

44
import voyageai # type: ignore
55
from langchain_core.embeddings import Embeddings
@@ -20,6 +20,12 @@
2020
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
2121
DEFAULT_VOYAGE_3_BATCH_SIZE = 10
2222
DEFAULT_BATCH_SIZE = 7
23+
MAX_DOCUMENTS_PER_REQUEST = 1_000
24+
DEFAULT_MAX_TOKENS_PER_REQUEST = 120_000
25+
TOKEN_LIMIT_OVERRIDES: Tuple[Tuple[int, Tuple[str, ...]], ...] = (
26+
(1_000_000, ("voyage-3.5-lite", "voyage-3-lite")),
27+
(320_000, ("voyage-3.5", "voyage-3", "voyage-2", "voyage-02")),
28+
)
2329

2430

2531
class VoyageAIEmbeddings(BaseModel, Embeddings):
@@ -85,21 +91,69 @@ def validate_environment(self) -> Self:
8591
self._aclient = voyageai.client_async.AsyncClient(api_key=api_key_str)
8692
return self
8793

88-
def _get_batch_iterator(self, texts: List[str]) -> Iterable:
89-
if self.show_progress_bar:
90-
try:
91-
from tqdm.auto import tqdm # type: ignore
92-
except ImportError as e:
93-
raise ImportError(
94-
"Must have tqdm installed if `show_progress_bar` is set to True. "
95-
"Please install with `pip install tqdm`."
96-
) from e
94+
def _max_documents_per_batch(self) -> int:
95+
"""Return the maximum number of documents allowed in a single request."""
96+
return max(1, min(self.batch_size, MAX_DOCUMENTS_PER_REQUEST))
9797

98-
_iter = tqdm(range(0, len(texts), self.batch_size))
99-
else:
100-
_iter = range(0, len(texts), self.batch_size) # type: ignore
98+
def _max_tokens_per_batch(self) -> int:
99+
"""Return the maximum number of tokens allowed for the current model."""
100+
model_name = self.model
101+
for limit, models in TOKEN_LIMIT_OVERRIDES:
102+
if model_name in models:
103+
return limit
104+
return DEFAULT_MAX_TOKENS_PER_REQUEST
101105

102-
return _iter
106+
def _token_lengths(self, texts: List[str]) -> List[int]:
107+
"""Return token lengths for texts using the Voyage client tokenizer."""
108+
try:
109+
tokenized = self._client.tokenize(texts, self.model)
110+
except Exception:
111+
logger.debug("Failed to tokenize texts for model %s", self.model)
112+
raise
113+
return [len(tokens) for tokens in tokenized]
114+
115+
def _iter_token_safe_batch_slices(
116+
self, texts: List[str]
117+
) -> Iterator[Tuple[int, int]]:
118+
"""Yield (start, end) indices for batches within token and length limits."""
119+
if not texts:
120+
return
121+
122+
token_lengths = self._token_lengths(texts)
123+
max_docs = self._max_documents_per_batch()
124+
max_tokens = self._max_tokens_per_batch()
125+
126+
index = 0
127+
total_texts = len(texts)
128+
while index < total_texts:
129+
start = index
130+
batch_tokens = 0
131+
batch_docs = 0
132+
while index < total_texts and batch_docs < max_docs:
133+
current_tokens = token_lengths[index]
134+
if batch_docs > 0 and batch_tokens + current_tokens > max_tokens:
135+
break
136+
137+
if current_tokens > max_tokens and batch_docs == 0:
138+
logger.warning(
139+
"Text at index %s exceeds Voyage token limit (%s > %s). "
140+
"Sending as a single-item batch; API may truncate or error.",
141+
index,
142+
current_tokens,
143+
max_tokens,
144+
)
145+
index += 1
146+
batch_docs += 1
147+
batch_tokens = current_tokens
148+
break
149+
150+
batch_tokens += current_tokens
151+
batch_docs += 1
152+
index += 1
153+
154+
if start == index:
155+
index += 1
156+
yield (start, index)
103157

104158
def _is_context_model(self) -> bool:
105159
"""Check if the model is a contextualized embedding model."""
@@ -120,16 +174,36 @@ def _embed_context(
120174
def _embed_regular(self, texts: List[str], input_type: str) -> List[List[float]]:
121175
"""Embed using regular embedding API."""
122176
embeddings: List[List[float]] = []
123-
_iter = self._get_batch_iterator(texts)
124-
for i in _iter:
125-
r = self._client.embed(
126-
texts[i : i + self.batch_size],
127-
model=self.model,
128-
input_type=input_type,
129-
truncation=self.truncation,
130-
output_dimension=self.output_dimension,
131-
).embeddings
132-
embeddings.extend(cast(Iterable[List[float]], r))
177+
progress = None
178+
if self.show_progress_bar:
179+
try:
180+
from tqdm.auto import tqdm # type: ignore
181+
except ImportError as e:
182+
raise ImportError(
183+
"Must have tqdm installed if `show_progress_bar` is set to True. "
184+
"Please install with `pip install tqdm`."
185+
) from e
186+
187+
progress = tqdm(total=len(texts))
188+
189+
try:
190+
for start, end in self._iter_token_safe_batch_slices(texts):
191+
if start == end:
192+
continue
193+
batch = texts[start:end]
194+
r = self._client.embed(
195+
batch,
196+
model=self.model,
197+
input_type=input_type,
198+
truncation=self.truncation,
199+
output_dimension=self.output_dimension,
200+
).embeddings
201+
embeddings.extend(cast(Iterable[List[float]], r))
202+
if progress is not None:
203+
progress.update(len(batch))
204+
finally:
205+
if progress is not None:
206+
progress.close()
133207
return embeddings
134208

135209
def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -163,16 +237,36 @@ async def _aembed_regular(
163237
) -> List[List[float]]:
164238
"""Async embed using regular embedding API."""
165239
embeddings: List[List[float]] = []
166-
_iter = self._get_batch_iterator(texts)
167-
for i in _iter:
168-
r = await self._aclient.embed(
169-
texts[i : i + self.batch_size],
170-
model=self.model,
171-
input_type=input_type,
172-
truncation=self.truncation,
173-
output_dimension=self.output_dimension,
174-
)
175-
embeddings.extend(cast(Iterable[List[float]], r.embeddings))
240+
progress = None
241+
if self.show_progress_bar:
242+
try:
243+
from tqdm.auto import tqdm # type: ignore
244+
except ImportError as e:
245+
raise ImportError(
246+
"Must have tqdm installed if `show_progress_bar` is set to True. "
247+
"Please install with `pip install tqdm`."
248+
) from e
249+
250+
progress = tqdm(total=len(texts))
251+
252+
try:
253+
for start, end in self._iter_token_safe_batch_slices(texts):
254+
if start == end:
255+
continue
256+
batch = texts[start:end]
257+
r = await self._aclient.embed(
258+
batch,
259+
model=self.model,
260+
input_type=input_type,
261+
truncation=self.truncation,
262+
output_dimension=self.output_dimension,
263+
)
264+
embeddings.extend(cast(Iterable[List[float]], r.embeddings))
265+
if progress is not None:
266+
progress.update(len(batch))
267+
finally:
268+
if progress is not None:
269+
progress.close()
176270
return embeddings
177271

178272
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:

libs/voyageai/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ test = [
3434
"pytest-asyncio<1.0.0,>=0.21.1",
3535
"pytest-socket<1.0.0,>=0.7.0",
3636
"numpy<2.0.0,>=1.24.0; python_version < \"3.12\"",
37-
"numpy<2.0.0,>=1.26.0; python_version >= \"3.12\"",
37+
"numpy<2.0.0,>=1.26.0; python_version >= \"3.12\" and python_version < \"3.13\"",
38+
"numpy>=2.1.0; python_version >= \"3.13\"",
3839
]
3940
codespell = ["codespell<3.0.0,>=2.2.0"]
4041
test_integration = []

libs/voyageai/tests/integration_tests/test_embeddings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
"""Test VoyageAI embeddings."""
22

3+
import os
4+
import pytest
5+
36
from langchain_voyageai import VoyageAIEmbeddings
47

58
# Please set VOYAGE_API_KEY in the environment variables
9+
pytestmark = pytest.mark.skipif(
10+
"VOYAGE_API_KEY" not in os.environ,
11+
reason="VOYAGE_API_KEY environment variable required for Voyage integration tests",
12+
)
13+
614
MODEL = "voyage-2"
715

816

libs/voyageai/tests/integration_tests/test_rerank.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""Test the voyageai reranker."""
22

33
import os
4+
import pytest
45

56
from langchain_core.documents import Document
67

78
from langchain_voyageai.rerank import VoyageAIRerank
89

910

11+
pytestmark = pytest.mark.skipif(
12+
"VOYAGE_API_KEY" not in os.environ,
13+
reason="VOYAGE_API_KEY environment variable required for Voyage integration tests",
14+
)
15+
16+
1017
def test_voyageai_reranker_init() -> None:
1118
"""Test the voyageai reranker initializes correctly."""
1219
VoyageAIRerank(voyage_api_key="foo", model="foo") # type: ignore[arg-type]

libs/voyageai/tests/unit_tests/test_embeddings.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test embedding model integration."""
22

3+
from typing import List
4+
35
from langchain_core.embeddings import Embeddings
46
from pydantic import SecretStr
57

@@ -137,3 +139,74 @@ def test_contextual_model_variants() -> None:
137139
assert (
138140
emb._is_context_model() is True
139141
), f"Model {model} should be detected as contextual"
142+
143+
144+
class _StubResponse:
145+
def __init__(self, count: int) -> None:
146+
self.embeddings = [[float(i)] for i in range(count)]
147+
148+
149+
class _StubClient:
150+
def __init__(self, token_lengths: List[int], recorded_batches: List[List[str]]) -> None:
151+
self._token_lengths = token_lengths
152+
self._recorded_batches = recorded_batches
153+
154+
def tokenize(self, texts: List[str], model: str) -> List[List[int]]: # type: ignore[override]
155+
assert len(texts) == len(self._token_lengths)
156+
return [list(range(length)) for length in self._token_lengths]
157+
158+
def embed(self, texts: List[str], **_: object) -> _StubResponse: # type: ignore[override]
159+
batch = list(texts)
160+
self._recorded_batches.append(batch)
161+
return _StubResponse(len(batch))
162+
163+
164+
def test_embed_regular_splits_on_token_limit(monkeypatch) -> None:
165+
texts = ["text-a", "text-b", "text-c", "text-d"]
166+
# voyage-3.5 limit is 320k tokens per request. Force batches of two items each.
167+
token_lengths = [150_000, 150_000, 150_000, 150_000]
168+
recorded_batches: List[List[str]] = []
169+
emb = VoyageAIEmbeddings(
170+
voyage_api_key=SecretStr("NOT_A_VALID_KEY"), # type: ignore
171+
model="voyage-3.5",
172+
batch_size=10,
173+
)
174+
stub_client = _StubClient(token_lengths, recorded_batches)
175+
monkeypatch.setattr(emb, "_client", stub_client, raising=False)
176+
177+
result = emb._embed_regular(texts, "document")
178+
179+
assert recorded_batches == [["text-a", "text-b"], ["text-c", "text-d"]]
180+
assert len(result) == len(texts)
181+
182+
183+
def test_iter_token_safe_batch_respects_custom_batch_size(monkeypatch) -> None:
184+
texts = [f"chunk-{i}" for i in range(5)]
185+
token_lengths = [5] * len(texts)
186+
recorded_batches: List[List[str]] = []
187+
emb = VoyageAIEmbeddings(
188+
voyage_api_key=SecretStr("NOT_A_VALID_KEY"), # type: ignore
189+
model="voyage-3.5-lite",
190+
batch_size=2,
191+
)
192+
stub_client = _StubClient(token_lengths, recorded_batches)
193+
monkeypatch.setattr(emb, "_client", stub_client, raising=False)
194+
195+
slices = list(emb._iter_token_safe_batch_slices(texts))
196+
assert slices == [(0, 2), (2, 4), (4, 5)]
197+
198+
199+
def test_iter_token_safe_batch_handles_single_oversized_text(monkeypatch) -> None:
200+
texts = ["oversized"]
201+
token_lengths = [500_000]
202+
recorded_batches: List[List[str]] = []
203+
emb = VoyageAIEmbeddings(
204+
voyage_api_key=SecretStr("NOT_A_VALID_KEY"), # type: ignore
205+
model="voyage-3-large",
206+
batch_size=5,
207+
)
208+
stub_client = _StubClient(token_lengths, recorded_batches)
209+
monkeypatch.setattr(emb, "_client", stub_client, raising=False)
210+
211+
slices = list(emb._iter_token_safe_batch_slices(texts))
212+
assert slices == [(0, 1)]

libs/voyageai/tests/unit_tests/test_rerank.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from langchain_core.documents import Document
66
from voyageai.api_resources import VoyageResponse # type: ignore
77
from voyageai.object import RerankingObject # type: ignore
8+
import voyageai
89

910
from langchain_voyageai.rerank import VoyageAIRerank
1011

@@ -47,8 +48,11 @@ def get_mock_rerank_result() -> RerankingObject:
4748

4849

4950
@pytest.mark.requires("voyageai")
50-
def test_rerank_unit_test(mocker: Any) -> None:
51-
mocker.patch("voyageai.Client.rerank").return_value = get_mock_rerank_result()
51+
def test_rerank_unit_test(monkeypatch: pytest.MonkeyPatch) -> None:
52+
def _mock_rerank(*_: Any, **__: Any) -> RerankingObject:
53+
return get_mock_rerank_result()
54+
55+
monkeypatch.setattr(voyageai.Client, "rerank", _mock_rerank)
5256
expected_result = [
5357
Document(
5458
page_content="Photosynthesis in plants converts light energy into "

0 commit comments

Comments
 (0)