Skip to content

Commit daa7aec

Browse files
authored
Support base64-encoded image in inference cache (#1285)
1 parent 8d3ae68 commit daa7aec

File tree

6 files changed

+255
-31
lines changed

6 files changed

+255
-31
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
opentelemetry-api==1.33.1
1010
opentelemetry-sdk==1.33.1
1111

12-
cachetools==6.1.0
12+
cachetools==6.1.0
13+
blake3==1.0.5

src/marqo/inference/inference_cache/caching_inference.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import hashlib
2-
from typing import Tuple, List
2+
from typing import Tuple, List, Optional
33

4+
import blake3
45
import numpy as np
56
import orjson
67

78
from marqo.core.inference.api import Inference, InferenceRequest, InferenceResult, Modality, \
89
InferenceErrorModel
10+
from marqo.core.inference.modality_utils import is_base64_image
911
from marqo.inference.inference_cache.marqo_inference_cache import MarqoInferenceCache
1012

1113

@@ -24,7 +26,12 @@ def vectorise(self, request: InferenceRequest) -> InferenceResult:
2426
contents_to_vectorise: List[str] = []
2527

2628
for index, content in enumerate(request.contents):
27-
embedding = self.inference_cache.get(model_cache_key, content)
29+
content_cache_key = self.content_cache_key(content, request.modality)
30+
if not content_cache_key:
31+
contents_to_vectorise.append(content)
32+
continue
33+
34+
embedding = self.inference_cache.get(model_cache_key, content_cache_key)
2835
if embedding is not None:
2936
cached_result.append((index, content, embedding))
3037
else:
@@ -43,7 +50,9 @@ def vectorise(self, request: InferenceRequest) -> InferenceResult:
4350
f"Preprocessing config: "
4451
f"{orjson.dumps(dict(new_request.preprocessing_config)).decode('utf-8')}")
4552
content, embedding = r[0]
46-
self.inference_cache.set(model_cache_key, content, embedding)
53+
content_cache_key = self.content_cache_key(content, request.modality)
54+
if content_cache_key:
55+
self.inference_cache.set(model_cache_key, content_cache_key, embedding)
4756

4857
# Merge result
4958
if cached_result:
@@ -61,10 +70,37 @@ def model_cache_key(self, model_properties) -> str:
6170
data = orjson.dumps(model_properties, option=orjson.OPT_SORT_KEYS)
6271
return hashlib.md5(data).hexdigest()
6372

73+
def content_cache_key(self, content: str, modality: Modality) -> Optional[str]:
74+
"""
75+
Generate appropriate cache key for content based on modality.
76+
77+
For TEXT modality: use content directly
78+
For IMAGE modality:
79+
- if base64 image: use blake3 hash with prefix
80+
- otherwise: use content directly (will be skipped in caching logic)
81+
82+
Args:
83+
content: The content string
84+
modality: The modality type
85+
86+
Returns:
87+
Cache key string, None if it should not be cached
88+
"""
89+
if modality == Modality.TEXT:
90+
# Use original content for text and non-base64 images
91+
return content
92+
elif modality == Modality.IMAGE and is_base64_image(content):
93+
# Use blake3 hash for base64 images to save memory
94+
hash_digest = blake3.blake3(content.encode()).hexdigest()
95+
return f"blake3:{hash_digest}"
96+
else:
97+
# should not cache non-base64-encoded images
98+
return None
99+
64100
def should_skip_cache(self, request):
65101
return (
66102
not request.use_inference_cache
67103
or request.device # device is only specified to debug embedding, skip caching
68-
or request.modality != Modality.TEXT # we only support text modality for now
104+
or request.modality not in [Modality.TEXT, Modality.IMAGE] # we support text and image modalities
69105
or request.preprocessing_config.should_chunk # we do not support caching chunks
70106
)

tests/api_tests/v1/tests/application_tests/test_env_var_changes.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,13 @@
1818
this test suite's runtime from growing too large.
1919
"""
2020
import json
21-
import unittest
2221
from concurrent.futures import ThreadPoolExecutor, as_completed
2322
from typing import Callable, Optional, List
2423

25-
import math
26-
24+
from marqo import Client
2725
from tests import marqo_test
2826
from tests import utilities
2927

30-
from marqo import Client
31-
3228

3329
class TestEnvVarChanges(marqo_test.MarqoTestCase):
3430

@@ -83,7 +79,7 @@ def test_inference_cache(self):
8379
"""
8480

8581
# Restart marqo with new max values
86-
new_models = ["hf/e5-large-v2"]
82+
new_models = ["open_clip/ViT-B-32/laion2b_s34b_b79k"]
8783
index_name = "test_multiple_env_vars"
8884
utilities.rerun_marqo_with_env_vars(
8985
env_vars=[
@@ -111,10 +107,12 @@ def test_inference_cache(self):
111107
telemetry_client = Client(**self.client_settings, return_telemetry=True)
112108

113109
min_inference_time_ms = 8 # inference usually takes at least 8ms
114-
cache_reading_time_ms = 2 # if it hits cache, it's usually less than 2ms
110+
cache_reading_time_ms = 3 # if it hits cache, the pipeline should take less than 3ms
115111

116112
# Test search query's embedding is cached when inference cache is enabled
117-
for query in ["test", {"random": 1, "query": 2}]:
113+
base64_image = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
114+
image_url = marqo_test.TestImageUrls.HIPPO_STATUE.value
115+
for query in ["test", {"random": 1, "query": 2}, base64_image, image_url]:
118116
with self.subTest(f"Search query: {query}"):
119117
# Single query
120118
# First search that misses cache should take longer
@@ -129,7 +127,13 @@ def test_inference_cache(self):
129127
inference_latency = self._run_in_threads(
130128
lambda client: client.index(index_name).search(q=query),
131129
max_workers=1, count=10, telemetry_name="search.vector_inference_full_pipeline")
132-
self.assertTrue(sum(inference_latency) / 10 < cache_reading_time_ms, inference_latency)
130+
131+
if query == image_url:
132+
# image url is not cached, so avg latency will usually be > min_inference_time_ms
133+
self.assertTrue(sum(inference_latency) / 10 > min_inference_time_ms, inference_latency)
134+
else:
135+
# other queries are all cached, so avg latency should be < cache_reading_time_ms
136+
self.assertTrue(sum(inference_latency) / 10 < cache_reading_time_ms, inference_latency)
133137

134138
# Test to ensure inference cache is not working for add_documents:
135139
with self.subTest("Add document"):

tests/integ_tests/inference/inference_cache/test_inference_cache.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from orjson import orjson
1414

1515
from marqo.core.inference.api import InferenceRequest, Modality, ModelConfig, TextPreprocessingConfig, Inference, \
16-
InferenceResult, InferenceErrorModel
16+
InferenceResult, InferenceErrorModel, ImagePreprocessingConfig
1717
from marqo.inference.inference_cache.caching_inference import CachingInference
1818

1919

@@ -199,6 +199,82 @@ def test_caching_inference_should_capture_key_metrics(self):
199199

200200
provider.shutdown()
201201

202+
def test_base64_image_selective_caching(self):
203+
"""Test selective caching: base64 images cached, URL images processed normally."""
204+
caching_inference = CachingInference(self.inference_local, 10, "LRU")
205+
206+
base64_png = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
207+
base64_jpeg = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/"
208+
url_image = "https://example.com/image.jpg"
209+
210+
# Mixed request: 2 base64 images + 1 URL
211+
mixed_request = InferenceRequest(
212+
modality=Modality.IMAGE,
213+
contents=[base64_png, url_image, base64_jpeg],
214+
model_config=ModelConfig(
215+
model_name="test/clip-model",
216+
model_properties={
217+
"name": "test-clip-model",
218+
"dimensions": 512,
219+
"type": "clip"
220+
}
221+
),
222+
preprocessing_config=ImagePreprocessingConfig(should_chunk=False),
223+
use_inference_cache=True
224+
)
225+
226+
# First call - base64 images should be cached, URL processed normally
227+
result1 = caching_inference.vectorise(mixed_request)
228+
229+
# Verify cache contains blake3 keys for both base64 images
230+
import blake3
231+
model_key = caching_inference.model_cache_key(mixed_request.model_config.model_properties)
232+
233+
hash1 = blake3.blake3(base64_png.encode()).hexdigest()
234+
hash2 = blake3.blake3(base64_jpeg.encode()).hexdigest()
235+
cache_key1 = f"blake3:{hash1}"
236+
cache_key2 = f"blake3:{hash2}"
237+
238+
cached_embedding1 = caching_inference.inference_cache.get(model_key, cache_key1)
239+
cached_embedding2 = caching_inference.inference_cache.get(model_key, cache_key2)
240+
241+
self.assertIsNotNone(cached_embedding1, "First base64 image should be cached")
242+
self.assertIsNotNone(cached_embedding2, "Second base64 image should be cached")
243+
244+
# Verify URL image is NOT cached
245+
url_cached_embedding = caching_inference.inference_cache.get(model_key, url_image)
246+
self.assertIsNone(url_cached_embedding, "URL image should not be cached")
247+
248+
# Verify cache size (only 2 base64 images cached)
249+
self.assertEqual(caching_inference.inference_cache._cache.currsize, 2)
250+
251+
# Second call with same mixed content
252+
result2 = caching_inference.vectorise(mixed_request)
253+
254+
# Base64 results should return original base64 content (not blake3 keys)
255+
png_content1, png_embedding1 = result1.result[0][0]
256+
png_content2, png_embedding2 = result2.result[0][0]
257+
258+
# Content should be original base64, embeddings should be identical (from cache)
259+
self.assertEqual(png_content1, base64_png)
260+
self.assertEqual(png_content2, base64_png)
261+
self.assertTrue(np.array_equal(png_embedding1, png_embedding2))
262+
263+
jpeg_content1, jpeg_embedding1 = result1.result[2][0]
264+
jpeg_content2, jpeg_embedding2 = result2.result[2][0]
265+
266+
# Content should be original base64, embeddings should be identical (from cache)
267+
self.assertEqual(jpeg_content1, base64_jpeg)
268+
self.assertEqual(jpeg_content2, base64_jpeg)
269+
self.assertTrue(np.array_equal(jpeg_embedding1, jpeg_embedding2))
270+
271+
# URL results should be unchanged (original URL returned)
272+
url_content1, url_embedding1 = result1.result[1][0]
273+
url_content2, url_embedding2 = result2.result[1][0]
274+
self.assertEqual(url_content1, url_image) # Original URL unchanged
275+
self.assertEqual(url_content2, url_image) # Original URL unchanged
276+
self.assertTrue(np.array_equal(url_embedding1, url_embedding2))
277+
202278
def _assert_metric_value(self, metric_data: MetricsData, name: str, expected_value: Any):
203279
cache_metrics = metric_data.resource_metrics[0].scope_metrics[0].metrics
204280
metric = next((metric for metric in cache_metrics if metric.name == name), None)

tests/integ_tests/core/inference/test_cache.py renamed to tests/unit_tests/marqo/inference/inference_cache/test_cache.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
import random
2-
import time
31
import unittest
4-
from concurrent.futures import ThreadPoolExecutor
5-
from queue import Queue
62

7-
import numpy as np
8-
9-
from marqo.api.exceptions import EnvVarError
103
from marqo.inference.inference_cache.marqo_lfu_cache import MarqoLFUCache
114
from marqo.inference.inference_cache.marqo_lru_cache import MarqoLRUCache
125

136

14-
class TestLFUCache(unittest.TestCase):
7+
class TestCache(unittest.TestCase):
158
"""This class tests the LRU and LFU cache implementations."""
169

1710
def setUp(self):

0 commit comments

Comments
 (0)