1
1
import logging
2
- from typing import Any , Iterable , List , Literal , Optional , cast
2
+ from typing import Any , Iterable , Iterator , List , Literal , Optional , Tuple , cast
3
3
4
4
import voyageai # type: ignore
5
5
from langchain_core .embeddings import Embeddings
# Default document batch sizes per model family (documents per embed() call).
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
DEFAULT_VOYAGE_3_BATCH_SIZE = 10
DEFAULT_BATCH_SIZE = 7
# Hard API ceiling on the number of documents in a single request.
MAX_DOCUMENTS_PER_REQUEST = 1_000
# Fallback per-request token budget for models without an explicit override.
DEFAULT_MAX_TOKENS_PER_REQUEST = 120_000
# (token_limit, model_names) pairs, checked in order — first match wins.
# NOTE(review): values presumably mirror Voyage AI's published per-request
# token caps — confirm against the current API documentation before changing.
TOKEN_LIMIT_OVERRIDES: Tuple[Tuple[int, Tuple[str, ...]], ...] = (
    (1_000_000, ("voyage-3.5-lite", "voyage-3-lite")),
    (320_000, ("voyage-3.5", "voyage-3", "voyage-2", "voyage-02")),
)
23
29
24
30
25
31
class VoyageAIEmbeddings (BaseModel , Embeddings ):
@@ -85,21 +91,69 @@ def validate_environment(self) -> Self:
85
91
self ._aclient = voyageai .client_async .AsyncClient (api_key = api_key_str )
86
92
return self
87
93
88
- def _get_batch_iterator (self , texts : List [str ]) -> Iterable :
89
- if self .show_progress_bar :
90
- try :
91
- from tqdm .auto import tqdm # type: ignore
92
- except ImportError as e :
93
- raise ImportError (
94
- "Must have tqdm installed if `show_progress_bar` is set to True. "
95
- "Please install with `pip install tqdm`."
96
- ) from e
94
+ def _max_documents_per_batch (self ) -> int :
95
+ """Return the maximum number of documents allowed in a single request."""
96
+ return max (1 , min (self .batch_size , MAX_DOCUMENTS_PER_REQUEST ))
97
97
98
- _iter = tqdm (range (0 , len (texts ), self .batch_size ))
99
- else :
100
- _iter = range (0 , len (texts ), self .batch_size ) # type: ignore
98
+ def _max_tokens_per_batch (self ) -> int :
99
+ """Return the maximum number of tokens allowed for the current model."""
100
+ model_name = self .model
101
+ for limit , models in TOKEN_LIMIT_OVERRIDES :
102
+ if model_name in models :
103
+ return limit
104
+ return DEFAULT_MAX_TOKENS_PER_REQUEST
101
105
102
- return _iter
106
+ def _token_lengths (self , texts : List [str ]) -> List [int ]:
107
+ """Return token lengths for texts using the Voyage client tokenizer."""
108
+ try :
109
+ tokenized = self ._client .tokenize (texts , self .model )
110
+ except Exception :
111
+ logger .debug ("Failed to tokenize texts for model %s" , self .model )
112
+ raise
113
+ return [len (tokens ) for tokens in tokenized ]
114
+
115
    def _iter_token_safe_batch_slices(
        self, texts: List[str]
    ) -> Iterator[Tuple[int, int]]:
        """Yield (start, end) indices for batches within token and length limits.

        Greedily packs consecutive texts into a batch until adding the next
        text would exceed either the per-request token budget or the
        per-request document count. Every text appears in exactly one slice,
        in input order.
        """
        if not texts:
            return

        # Tokenize everything up front; one tokenizer call for all texts.
        token_lengths = self._token_lengths(texts)
        max_docs = self._max_documents_per_batch()
        max_tokens = self._max_tokens_per_batch()

        index = 0
        total_texts = len(texts)
        while index < total_texts:
            start = index
            batch_tokens = 0
            batch_docs = 0
            while index < total_texts and batch_docs < max_docs:
                current_tokens = token_lengths[index]
                # Next text would push a non-empty batch over the token
                # budget: close the current batch and start a new one.
                if batch_docs > 0 and batch_tokens + current_tokens > max_tokens:
                    break

                # A single text that alone exceeds the budget: send it as a
                # one-item batch rather than dropping it or looping forever.
                if current_tokens > max_tokens and batch_docs == 0:
                    logger.warning(
                        "Text at index %s exceeds Voyage token limit (%s > %s). "
                        "Sending as a single-item batch; API may truncate or error.",
                        index,
                        current_tokens,
                        max_tokens,
                    )
                    index += 1
                    batch_docs += 1
                    batch_tokens = current_tokens
                    break

                batch_tokens += current_tokens
                batch_docs += 1
                index += 1

            # Defensive: guarantee forward progress even if the inner loop
            # admitted nothing (should not happen since max_docs >= 1).
            if start == index:
                index += 1
            yield (start, index)
103
157
104
158
def _is_context_model (self ) -> bool :
105
159
"""Check if the model is a contextualized embedding model."""
@@ -120,16 +174,36 @@ def _embed_context(
120
174
def _embed_regular (self , texts : List [str ], input_type : str ) -> List [List [float ]]:
121
175
"""Embed using regular embedding API."""
122
176
embeddings : List [List [float ]] = []
123
- _iter = self ._get_batch_iterator (texts )
124
- for i in _iter :
125
- r = self ._client .embed (
126
- texts [i : i + self .batch_size ],
127
- model = self .model ,
128
- input_type = input_type ,
129
- truncation = self .truncation ,
130
- output_dimension = self .output_dimension ,
131
- ).embeddings
132
- embeddings .extend (cast (Iterable [List [float ]], r ))
177
+ progress = None
178
+ if self .show_progress_bar :
179
+ try :
180
+ from tqdm .auto import tqdm # type: ignore
181
+ except ImportError as e :
182
+ raise ImportError (
183
+ "Must have tqdm installed if `show_progress_bar` is set to True. "
184
+ "Please install with `pip install tqdm`."
185
+ ) from e
186
+
187
+ progress = tqdm (total = len (texts ))
188
+
189
+ try :
190
+ for start , end in self ._iter_token_safe_batch_slices (texts ):
191
+ if start == end :
192
+ continue
193
+ batch = texts [start :end ]
194
+ r = self ._client .embed (
195
+ batch ,
196
+ model = self .model ,
197
+ input_type = input_type ,
198
+ truncation = self .truncation ,
199
+ output_dimension = self .output_dimension ,
200
+ ).embeddings
201
+ embeddings .extend (cast (Iterable [List [float ]], r ))
202
+ if progress is not None :
203
+ progress .update (len (batch ))
204
+ finally :
205
+ if progress is not None :
206
+ progress .close ()
133
207
return embeddings
134
208
135
209
def embed_documents (self , texts : List [str ]) -> List [List [float ]]:
@@ -163,16 +237,36 @@ async def _aembed_regular(
163
237
) -> List [List [float ]]:
164
238
"""Async embed using regular embedding API."""
165
239
embeddings : List [List [float ]] = []
166
- _iter = self ._get_batch_iterator (texts )
167
- for i in _iter :
168
- r = await self ._aclient .embed (
169
- texts [i : i + self .batch_size ],
170
- model = self .model ,
171
- input_type = input_type ,
172
- truncation = self .truncation ,
173
- output_dimension = self .output_dimension ,
174
- )
175
- embeddings .extend (cast (Iterable [List [float ]], r .embeddings ))
240
+ progress = None
241
+ if self .show_progress_bar :
242
+ try :
243
+ from tqdm .auto import tqdm # type: ignore
244
+ except ImportError as e :
245
+ raise ImportError (
246
+ "Must have tqdm installed if `show_progress_bar` is set to True. "
247
+ "Please install with `pip install tqdm`."
248
+ ) from e
249
+
250
+ progress = tqdm (total = len (texts ))
251
+
252
+ try :
253
+ for start , end in self ._iter_token_safe_batch_slices (texts ):
254
+ if start == end :
255
+ continue
256
+ batch = texts [start :end ]
257
+ r = await self ._aclient .embed (
258
+ batch ,
259
+ model = self .model ,
260
+ input_type = input_type ,
261
+ truncation = self .truncation ,
262
+ output_dimension = self .output_dimension ,
263
+ )
264
+ embeddings .extend (cast (Iterable [List [float ]], r .embeddings ))
265
+ if progress is not None :
266
+ progress .update (len (batch ))
267
+ finally :
268
+ if progress is not None :
269
+ progress .close ()
176
270
return embeddings
177
271
178
272
async def aembed_documents (self , texts : List [str ]) -> List [List [float ]]:
0 commit comments