|
4 | 4 | import os |
5 | 5 | from io import BytesIO |
6 | 6 | from pathlib import Path |
7 | | -from typing import Any, List, Optional, Union |
| 7 | +from typing import Any, Generator, List, Optional, Tuple, Union |
8 | 8 |
|
9 | 9 | import voyageai |
10 | 10 | from PIL import Image |
|
17 | 17 |
|
18 | 18 | logger = logging.getLogger(__name__) |
19 | 19 |
|
20 | | -DEFAULT_VOYAGE_2_BATCH_SIZE = 72 |
21 | | -DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30 |
22 | | -DEFAULT_VOYAGE_3_BATCH_SIZE = 10 |
23 | | -DEFAULT_BATCH_SIZE = 7 |
| 20 | +MAX_BATCH_SIZE = 1000 |
| 21 | + |
24 | 22 | MULTIMODAL_MODELS = ["voyage-multimodal-3"] |
| 23 | +CONTEXT_MODELS = ["voyage-context-3"] |
25 | 24 |
|
26 | 25 | SUPPORTED_IMAGE_FORMATS = {"png", "jpeg", "jpg", "webp", "gif"} |
27 | 26 |
|
| 27 | +VOYAGE_TOTAL_TOKEN_LIMITS = { |
| 28 | + "voyage-context-3": 32_000, |
| 29 | + "voyage-3.5-lite": 1_000_000, |
| 30 | +    "voyage-3.5": 32_000,  # total token budget capped at the model's 32k context window |
| 31 | + "voyage-2": 320_000, |
| 32 | + "voyage-3-large": 120_000, |
| 33 | + "voyage-code-3": 120_000, |
| 34 | + "voyage-large-2-instruct": 120_000, |
| 35 | + "voyage-finance-2": 120_000, |
| 36 | + "voyage-multilingual-2": 120_000, |
| 37 | + "voyage-law-2": 120_000, |
| 38 | + "voyage-large-2": 120_000, |
| 39 | + "voyage-3": 120_000, |
| 40 | + "voyage-3-lite": 120_000, |
| 41 | + "voyage-code-2": 120_000, |
| 42 | + "voyage-3-m-exp": 120_000, |
| 43 | +} |
| 44 | + |
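A quick sketch of the fallback behavior for models missing from this table; `_build_batches` below uses the same `.get(..., 120_000)` default. The second model name is hypothetical:

```python
# Known model: returns its entry from the table.
VOYAGE_TOTAL_TOKEN_LIMITS.get("voyage-3.5-lite", 120_000)    # 1_000_000

# Unlisted/future model (hypothetical name): conservative 120k default.
VOYAGE_TOTAL_TOKEN_LIMITS.get("some-future-model", 120_000)  # 120_000
```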
28 | 45 |
|
29 | 46 | class VoyageEmbedding(MultiModalEmbedding): |
30 | 47 | """ |
@@ -76,19 +93,7 @@ def __init__( |
76 | 93 | ) |
77 | 94 |
|
78 | 95 | if embed_batch_size is None: |
79 | | - embed_batch_size = ( |
80 | | - DEFAULT_VOYAGE_2_BATCH_SIZE |
81 | | - if model_name in ["voyage-2", "voyage-02"] |
82 | | - else ( |
83 | | - DEFAULT_VOYAGE_3_LITE_BATCH_SIZE |
84 | | - if model_name in ["voyage-3-lite", "voyage-3.5-lite"] |
85 | | - else ( |
86 | | - DEFAULT_VOYAGE_3_BATCH_SIZE |
87 | | - if model_name in ["voyage-3", "voyage-3.5", "voyage-context-3"] |
88 | | - else DEFAULT_BATCH_SIZE |
89 | | - ) |
90 | | - ) |
91 | | - ) |
| 96 | + embed_batch_size = MAX_BATCH_SIZE |
92 | 97 |
|
93 | 98 | super().__init__( |
94 | 99 | model_name=model_name, |
@@ -116,6 +121,32 @@ def _validate_image_format(file_type: str) -> bool: |
116 | 121 | def _texts_to_content(cls, input_strs: List[str]) -> List[dict]: |
117 | 122 | return [{"content": [{"type": "text", "text": x}]} for x in input_strs] |
118 | 123 |
|
| 124 | + def _build_batches( |
| 125 | + self, texts: List[str] |
| 126 | + ) -> Generator[Tuple[List[str], int], None, None]: |
| 127 | + """Generate batches of texts based on token limits.""" |
| 128 | + max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.model_name, 120_000) |
| 129 | + index = 0 |
| 130 | + |
| 131 | + while index < len(texts): |
| 132 | + batch: List[str] = [] |
| 133 | + batch_tokens = 0 |
| 134 | + while ( |
| 135 | + index < len(texts) |
| 136 | + and len(batch) < min(self.embed_batch_size, MAX_BATCH_SIZE) |
| 137 | + and batch_tokens < max_tokens_per_batch |
| 138 | + ): |
| 139 | + n_tokens = len( |
| 140 | + self._client.tokenize([texts[index]], model=self.model_name)[0] |
| 141 | + ) |
| 142 | + if batch_tokens + n_tokens > max_tokens_per_batch and len(batch) > 0: |
| 143 | + break |
| 144 | + batch_tokens += n_tokens |
| 145 | + batch.append(texts[index]) |
| 146 | + index += 1 |
| 147 | + |
| 148 | + yield batch, len(batch) |
| 149 | + |
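A minimal usage sketch of the generator above, assuming an already-constructed `VoyageEmbedding` instance named `embed_model` (the method is private, so this is illustration only; each text is measured with the client's `tokenize` call, as in the code above):

```python
# Illustration: iterate token-bounded batches of input texts.
texts = ["first document", "second document", "third document"]
for batch, batch_len in embed_model._build_batches(texts):
    # Each batch respects both embed_batch_size and the model's
    # total-token budget; batch_len is simply len(batch).
    print(f"embedding {batch_len} texts in one API call")
```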
119 | 150 | def _image_to_content(self, image_input: Union[str, Path, BytesIO]) -> Image: |
120 | 151 | """Convert an image to a base64 Data URL.""" |
121 | 152 | if isinstance(image_input, (str, Path)): |
@@ -177,41 +208,75 @@ async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: |
177 | 208 | return await self._aembed_image(img_file_path) |
178 | 209 |
|
179 | 210 | def _embed(self, texts: List[str], input_type: str) -> List[List[float]]: |
180 | | - if self.model_name in MULTIMODAL_MODELS: |
181 | | - return self._client.multimodal_embed( |
182 | | - inputs=self._texts_to_content(texts), |
183 | | - model=self.model_name, |
184 | | - input_type=input_type, |
185 | | - truncation=self.truncation, |
186 | | - ).embeddings |
187 | | - else: |
188 | | - return self._client.embed( |
189 | | - texts, |
190 | | - model=self.model_name, |
191 | | - input_type=input_type, |
192 | | - truncation=self.truncation, |
193 | | - output_dtype=self.output_dtype, |
194 | | - output_dimension=self.output_dimension, |
195 | | - ).embeddings |
| 211 | + """Embed texts with dynamic batching based on token limits.""" |
| 212 | + embeddings: List[List[float]] = [] |
| 213 | + |
| 214 | + for batch, _ in self._build_batches(texts): |
| 215 | + if self.model_name in CONTEXT_MODELS: |
| 216 | + r = self._client.contextualized_embed( |
| 217 | + inputs=[batch], |
| 218 | + model=self.model_name, |
| 219 | + input_type=input_type, |
| 220 | + output_dtype=self.output_dtype, |
| 221 | + output_dimension=self.output_dimension, |
| 222 | + ).results |
| 223 | + embeddings.extend(r[0].embeddings) |
| 224 | + elif self.model_name in MULTIMODAL_MODELS: |
| 225 | + batch_embeddings = self._client.multimodal_embed( |
| 226 | + inputs=self._texts_to_content(batch), |
| 227 | + model=self.model_name, |
| 228 | + input_type=input_type, |
| 229 | + truncation=self.truncation, |
| 230 | + ).embeddings |
| 231 | + embeddings.extend(batch_embeddings) |
| 232 | + else: |
| 233 | + batch_embeddings = self._client.embed( |
| 234 | + batch, |
| 235 | + model=self.model_name, |
| 236 | + input_type=input_type, |
| 237 | + truncation=self.truncation, |
| 238 | + output_dtype=self.output_dtype, |
| 239 | + output_dimension=self.output_dimension, |
| 240 | + ).embeddings |
| 241 | + embeddings.extend(batch_embeddings) |
| 242 | + |
| 243 | + return embeddings |
196 | 244 |
|
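Note the contextualized branch wraps the batch as `inputs=[batch]` and reads `results[0].embeddings`: the client treats each inner list as one document's chunks and returns one result per document. A shape sketch of what the code above relies on, assuming `client` is a `voyageai.Client` (chunk texts are illustrative):

```python
# Shape sketch for the contextualized path.
r = client.contextualized_embed(
    inputs=[["chunk a", "chunk b"]],  # one document -> its list of chunks
    model="voyage-context-3",
)
r.results[0].embeddings  # one embedding per chunk: [[...], [...]]
```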
197 | 245 | async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]: |
198 | | - if self.model_name in MULTIMODAL_MODELS: |
199 | | - r = await self._aclient.multimodal_embed( |
200 | | - inputs=self._texts_to_content(texts), |
201 | | - model=self.model_name, |
202 | | - input_type=input_type, |
203 | | - truncation=self.truncation, |
204 | | - ) |
205 | | - else: |
206 | | - r = await self._aclient.embed( |
207 | | - texts, |
208 | | - model=self.model_name, |
209 | | - input_type=input_type, |
210 | | - truncation=self.truncation, |
211 | | - output_dtype=self.output_dtype, |
212 | | - output_dimension=self.output_dimension, |
213 | | - ) |
214 | | - return r.embeddings |
| 246 | + """Asynchronously embed texts with dynamic batching based on token limits.""" |
| 247 | + embeddings: List[List[float]] = [] |
| 248 | + |
| 249 | + for batch, _ in self._build_batches(texts): |
| 250 | + if self.model_name in CONTEXT_MODELS: |
| 251 | + ar = await self._aclient.contextualized_embed( |
| 252 | + inputs=[batch], |
| 253 | + model=self.model_name, |
| 254 | + input_type=input_type, |
| 255 | + output_dtype=self.output_dtype, |
| 256 | + output_dimension=self.output_dimension, |
| 257 | + ) |
| 258 | + r = ar.results |
| 259 | + embeddings.extend(r[0].embeddings) |
| 260 | + elif self.model_name in MULTIMODAL_MODELS: |
| 261 | + r = await self._aclient.multimodal_embed( |
| 262 | + inputs=self._texts_to_content(batch), |
| 263 | + model=self.model_name, |
| 264 | + input_type=input_type, |
| 265 | + truncation=self.truncation, |
| 266 | + ) |
| 267 | + embeddings.extend(r.embeddings) |
| 268 | + else: |
| 269 | + r = await self._aclient.embed( |
| 270 | + batch, |
| 271 | + model=self.model_name, |
| 272 | + input_type=input_type, |
| 273 | + truncation=self.truncation, |
| 274 | + output_dtype=self.output_dtype, |
| 275 | + output_dimension=self.output_dimension, |
| 276 | + ) |
| 277 | + embeddings.extend(r.embeddings) |
| 278 | + |
| 279 | + return embeddings |
215 | 280 |
|
216 | 281 | def _get_query_embedding(self, query: str) -> List[float]: |
217 | 282 | """Get query embedding.""" |
|
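End to end, the batching stays transparent to callers. A minimal usage sketch, assuming the package is installed and `VOYAGE_API_KEY` is set in the environment (the model name here is an assumption):

```python
from llama_index.embeddings.voyageai import VoyageEmbedding

# Token-budget batching happens inside _embed/_aembed; callers just
# request embeddings as usual.
embed_model = VoyageEmbedding(model_name="voyage-3.5")
vectors = embed_model.get_text_embedding_batch(["hello", "world"])
assert len(vectors) == 2
```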