|
13 | 13 | from orjson import orjson |
14 | 14 |
|
15 | 15 | from marqo.core.inference.api import InferenceRequest, Modality, ModelConfig, TextPreprocessingConfig, Inference, \ |
16 | | - InferenceResult, InferenceErrorModel |
| 16 | + InferenceResult, InferenceErrorModel, ImagePreprocessingConfig |
17 | 17 | from marqo.inference.inference_cache.caching_inference import CachingInference |
18 | 18 |
|
19 | 19 |
|
@@ -199,6 +199,82 @@ def test_caching_inference_should_capture_key_metrics(self): |
199 | 199 |
|
200 | 200 | provider.shutdown() |
201 | 201 |
|
| 202 | + def test_base64_image_selective_caching(self): |
| 203 | + """Test selective caching: base64 images cached, URL images processed normally.""" |
| 204 | + caching_inference = CachingInference(self.inference_local, 10, "LRU") |
| 205 | + |
| 206 | + base64_png = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" |
| 207 | + base64_jpeg = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/" |
| 208 | + url_image = "https://example.com/image.jpg" |
| 209 | + |
| 210 | + # Mixed request: 2 base64 images + 1 URL |
| 211 | + mixed_request = InferenceRequest( |
| 212 | + modality=Modality.IMAGE, |
| 213 | + contents=[base64_png, url_image, base64_jpeg], |
| 214 | + model_config=ModelConfig( |
| 215 | + model_name="test/clip-model", |
| 216 | + model_properties={ |
| 217 | + "name": "test-clip-model", |
| 218 | + "dimensions": 512, |
| 219 | + "type": "clip" |
| 220 | + } |
| 221 | + ), |
| 222 | + preprocessing_config=ImagePreprocessingConfig(should_chunk=False), |
| 223 | + use_inference_cache=True |
| 224 | + ) |
| 225 | + |
| 226 | + # First call - base64 images should be cached, URL processed normally |
| 227 | + result1 = caching_inference.vectorise(mixed_request) |
| 228 | + |
| 229 | + # Verify cache contains blake3 keys for both base64 images |
| 230 | + import blake3 |
| 231 | + model_key = caching_inference.model_cache_key(mixed_request.model_config.model_properties) |
| 232 | + |
| 233 | + hash1 = blake3.blake3(base64_png.encode()).hexdigest() |
| 234 | + hash2 = blake3.blake3(base64_jpeg.encode()).hexdigest() |
| 235 | + cache_key1 = f"blake3:{hash1}" |
| 236 | + cache_key2 = f"blake3:{hash2}" |
| 237 | + |
| 238 | + cached_embedding1 = caching_inference.inference_cache.get(model_key, cache_key1) |
| 239 | + cached_embedding2 = caching_inference.inference_cache.get(model_key, cache_key2) |
| 240 | + |
| 241 | + self.assertIsNotNone(cached_embedding1, "First base64 image should be cached") |
| 242 | + self.assertIsNotNone(cached_embedding2, "Second base64 image should be cached") |
| 243 | + |
| 244 | + # Verify URL image is NOT cached |
| 245 | + url_cached_embedding = caching_inference.inference_cache.get(model_key, url_image) |
| 246 | + self.assertIsNone(url_cached_embedding, "URL image should not be cached") |
| 247 | + |
| 248 | + # Verify cache size (only 2 base64 images cached) |
| 249 | + self.assertEqual(caching_inference.inference_cache._cache.currsize, 2) |
| 250 | + |
| 251 | + # Second call with same mixed content |
| 252 | + result2 = caching_inference.vectorise(mixed_request) |
| 253 | + |
| 254 | + # Base64 results should return original base64 content (not blake3 keys) |
| 255 | + png_content1, png_embedding1 = result1.result[0][0] |
| 256 | + png_content2, png_embedding2 = result2.result[0][0] |
| 257 | + |
| 258 | + # Content should be original base64, embeddings should be identical (from cache) |
| 259 | + self.assertEqual(png_content1, base64_png) |
| 260 | + self.assertEqual(png_content2, base64_png) |
| 261 | + self.assertTrue(np.array_equal(png_embedding1, png_embedding2)) |
| 262 | + |
| 263 | + jpeg_content1, jpeg_embedding1 = result1.result[2][0] |
| 264 | + jpeg_content2, jpeg_embedding2 = result2.result[2][0] |
| 265 | + |
| 266 | + # Content should be original base64, embeddings should be identical (from cache) |
| 267 | + self.assertEqual(jpeg_content1, base64_jpeg) |
| 268 | + self.assertEqual(jpeg_content2, base64_jpeg) |
| 269 | + self.assertTrue(np.array_equal(jpeg_embedding1, jpeg_embedding2)) |
| 270 | + |
| 271 | + # URL results should be unchanged (original URL returned) |
| 272 | + url_content1, url_embedding1 = result1.result[1][0] |
| 273 | + url_content2, url_embedding2 = result2.result[1][0] |
| 274 | + self.assertEqual(url_content1, url_image) # Original URL unchanged |
| 275 | + self.assertEqual(url_content2, url_image) # Original URL unchanged |
| 276 | + self.assertTrue(np.array_equal(url_embedding1, url_embedding2)) |
| 277 | + |
202 | 278 | def _assert_metric_value(self, metric_data: MetricsData, name: str, expected_value: Any): |
203 | 279 | cache_metrics = metric_data.resource_metrics[0].scope_metrics[0].metrics |
204 | 280 | metric = next((metric for metric in cache_metrics if metric.name == name), None) |
|
0 commit comments