
Commit 650ac4d

Merge releases/2.18 to mainline (#1207)
2 parents 5545644 + 13a8228 commit 650ac4d


11 files changed, +254 −143 lines changed


examples/ClothingCLI/simple_marqo_demo.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def load_index(index_name: str, number_data: int) -> None:
 
     settings = {
         "treatUrlsAndPointersAsImages": True, # allows us to find an image file and index it
-        "model": "ViT-B/16"
+        "model": "open_clip/ViT-B-16/openai"
     }
 
     mq.create_index(index_name, settings_dict=settings)

examples/ClothingStreamlit/streamlit_marqo_demo.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def load_index(number_data):
 
     settings = {
         "treatUrlsAndPointersAsImages":True, # allows us to find an image file and index it
-        "model":"ViT-B/16"
+        "model": "open_clip/ViT-B-16/openai"
    }
 
     mq.create_index("demo-search-index", settings_dict=settings)

src/marqo/core/inference/api/exceptions.py

Lines changed: 5 additions & 0 deletions
@@ -29,3 +29,8 @@ class UnsupportedModalityError(InferenceError):
 class MediaExceedsMaxSizeError(InferenceError):
     """Raised when the media exceeds the maximum size limit"""
     pass
+
+
+class MediaMismatchError(InferenceError):
+    """Raised when the media does not match the expected type"""
+    pass
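Because MediaMismatchError subclasses InferenceError, existing handlers that catch the base class keep working while new code can target the mismatch case specifically. A small sketch of that; the check_media helper is illustrative and not part of the codebase:

from marqo.core.inference.api.exceptions import InferenceError, MediaMismatchError

def check_media(expected: str, detected: str) -> None:
    """Illustrative helper: raise the new error when the probed modality disagrees with the expected one."""
    if expected != detected:
        raise MediaMismatchError(f"expected {expected}, but the detected modality is {detected}")

try:
    check_media("video", "image")
except MediaMismatchError as err:
    print(f"modality mismatch: {err}")        # the new, more specific case
except InferenceError as err:
    print(f"other inference failure: {err}")  # base class still catches everything else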

src/marqo/core/structured_vespa_index/structured_add_document_handler.py

Lines changed: 19 additions & 28 deletions
@@ -3,31 +3,20 @@
 from marqo.api import exceptions as api_errors
 from marqo.core import constants
 from marqo.core.constants import MARQO_DOC_ID
-from marqo.core.inference.api import Modality, MediaDownloadError, Inference
-from marqo.core.inference.modality_utils import infer_modality
-from marqo.core.vespa_index.add_documents_handler import AddDocumentsHandler, AddDocumentsError
-from marqo.core.models.add_docs_params import AddDocsParams
+from marqo.core.inference.api import Modality, Inference
 from marqo.core.inference.tensor_fields_container import TensorFieldsContainer, TensorField
+from marqo.core.models.add_docs_params import AddDocsParams
 from marqo.core.models.marqo_index import FieldType, StructuredMarqoIndex
 from marqo.core.structured_vespa_index.structured_vespa_index import StructuredVespaIndex
+from marqo.core.vespa_index.add_documents_handler import AddDocumentsHandler, AddDocumentsError
 from marqo.exceptions import InvalidArgumentError
-
-from marqo.vespa.models import VespaDocument
-from marqo.vespa.models.get_document_response import Document
-
 # TODO deps to tensor_search needs to be removed
 from marqo.tensor_search import validation
+from marqo.vespa.models import VespaDocument
+from marqo.vespa.models.get_document_response import Document
 from marqo.vespa.vespa_client import VespaClient
 
 
-MODALITY_FIELD_TYPE_MAP = {
-    Modality.TEXT: FieldType.Text,
-    Modality.IMAGE: FieldType.ImagePointer,
-    Modality.VIDEO: FieldType.VideoPointer,
-    Modality.AUDIO: FieldType.AudioPointer,
-}
-
-
 class StructuredAddDocumentsHandler(AddDocumentsHandler):
     def __init__(self, marqo_index: StructuredMarqoIndex, add_docs_params: AddDocsParams, vespa_client: VespaClient,
                  inference: Inference):
@@ -74,20 +63,22 @@ def _handle_field(self, marqo_doc, field_name, field_content):
         marqo_doc[field_name] = content
 
     def _infer_modality(self, tensor_field: TensorField) -> Modality:
+        """
+        Infer modality based on tensor field type specified in the definition of structured index. Please note we
+        do not infer the modality from the content of the field here, any modality mismatch is detected later when
+        we download and preprocess the media content.
+        """
         if tensor_field.field_type == FieldType.Text:
             return Modality.TEXT
-
-        url = tensor_field.field_content
-        try:
-            modality = infer_modality(url, self.add_docs_params.media_download_headers)
-        except MediaDownloadError as err:
-            raise AddDocumentsError(f"Error processing {tensor_field.field_name}: {err.message}") from err
-
-        if MODALITY_FIELD_TYPE_MAP[modality] != tensor_field.field_type:
-            raise AddDocumentsError(f"Error processing {tensor_field.field_name}, detected as {modality.value}, "
-                                    f"but expected field type is {tensor_field.field_type}")
-
-        return modality
+        elif tensor_field.field_type == FieldType.ImagePointer:
+            return Modality.IMAGE
+        elif tensor_field.field_type == FieldType.VideoPointer:
+            return Modality.VIDEO
+        elif tensor_field.field_type == FieldType.AudioPointer:
+            return Modality.AUDIO
+        else:
+            raise AddDocumentsError(f"Error processing {tensor_field.field_name}, tensor field type "
+                                    f"{tensor_field.field_type} is not supported")
 
     def _validate_field(self, field_name: str, field_content: Any) -> None:
         try:
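The commit replaces the download-time infer_modality call with a direct lookup on the field type declared in the structured index schema; any content/type disagreement is now caught later, during media download and preprocessing (see the StreamingMediaProcessor changes below). A self-contained sketch of the same mapping, using stand-in enums (member values illustrative) so it runs outside Marqo:

from enum import Enum

class FieldType(Enum):   # stand-in for marqo.core.models.marqo_index.FieldType
    Text = "text"
    ImagePointer = "image_pointer"
    VideoPointer = "video_pointer"
    AudioPointer = "audio_pointer"

class Modality(Enum):    # stand-in for marqo.core.inference.api.Modality
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"

_FIELD_TYPE_TO_MODALITY = {
    FieldType.Text: Modality.TEXT,
    FieldType.ImagePointer: Modality.IMAGE,
    FieldType.VideoPointer: Modality.VIDEO,
    FieldType.AudioPointer: Modality.AUDIO,
}

def infer_modality_from_field_type(field_type: FieldType) -> Modality:
    """Dictionary form of the new if/elif chain: the schema field type alone decides the modality."""
    try:
        return _FIELD_TYPE_TO_MODALITY[field_type]
    except KeyError:
        raise ValueError(f"tensor field type {field_type} is not supported")

assert infer_modality_from_field_type(FieldType.ImagePointer) is Modality.IMAGE

This is essentially the removed MODALITY_FIELD_TYPE_MAP inverted: the field type is now the key, and it no longer has to be cross-checked against a downloaded file at this stage.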

src/marqo/core/vespa_index/add_documents_handler.py

Lines changed: 7 additions & 3 deletions
@@ -168,7 +168,8 @@ def add_documents(self) -> MarqoAddDocumentsResponse:
         self._populate_existing_tensors(existing_vespa_docs)
 
         # vectorise tensor fields
-        self._vectorise_tensor_fields()
+        with RequestMetricsStore.for_request().time("add_documents.inference.all"):
+            self._vectorise_tensor_fields()
 
         with RequestMetricsStore.for_request().time("add_documents.vespa.to_vespa_docs"):
             vespa_docs = self._convert_to_vespa_docs()
@@ -287,7 +288,8 @@ def _vectorise_tensor_fields(self) -> None:
         3. The result will be then populated to the tensor field. Individual errors happened during preprocessing
         and vectorisation will also be returned and collected by the `add_docs_response_collector`
         """
-        modalities = self._infer_modalities()
+        with RequestMetricsStore.for_request().time("add_documents.inference.infer_modality"):
+            modalities = self._infer_modalities()
 
         for modality in modalities:
             self._vectorise_fields(modality, for_top_level_field=True)
@@ -342,7 +344,9 @@ def subfield_predicate(f: TensorField) -> bool:
 
         # This method could raise InferenceError, we'll allow it propagate to the API layer and convert to proper
         # error response to return to users
-        inference_result = self.inference.vectorise(request)
+        with RequestMetricsStore.for_request().time(f"add_documents.inference.{modality}."
+                                                    f"is_subfield_{not for_top_level_field}.size_{len(tensor_fields)}"):
+            inference_result = self.inference.vectorise(request)
 
         if len(tensor_fields) != len(inference_result.result):
             raise InternalError(f'Inference result contains chunks and embeddings for {len(inference_result.result)} '
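The new timers wrap modality inference, each vectorise call, and the whole inference phase, with the per-call metric name encoding the modality, whether the fields are subfields, and the batch size. A minimal stand-in for the timing context manager, just to show the pattern (this is not Marqo's RequestMetricsStore implementation):

import time
from contextlib import contextmanager

@contextmanager
def timed(metrics: dict, key: str):
    """Record the elapsed time (in ms) of the wrapped block under `key`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        metrics[key] = (time.perf_counter() - start) * 1000.0

metrics = {}
with timed(metrics, "add_documents.inference.all"):
    time.sleep(0.05)  # placeholder for the vectorisation work
print(metrics)        # e.g. {'add_documents.inference.all': ~50.0}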

src/marqo/inference/media_download_and_preprocess/streaming_media_processor.py

Lines changed: 39 additions & 4 deletions
@@ -44,7 +44,14 @@ def __init__(
         self.modality = preprocessing_config.modality
 
         self.media_download_header = self._convert_headers_to_cli_format(preprocessing_config.download_header)
-        self.total_size, self.duration = self._fetch_file_metadata()
+        self.total_size, self.duration, self.probed_modality = self._fetch_file_metadata()
+
+        if self.modality != self.probed_modality:
+            raise MediaMismatchError(
+                f"Error processing media file {self.url}. The provided modality {self.modality} does not match the "
+                f"detected modality {self.probed_modality}. Please check your media file and try again. If you are using "
+                f"a structured index, check if your media file matches the field type"
+            )
 
         if self.total_size > preprocessing_config.max_media_size_bytes:
             raise MediaExceedsMaxSizeError(
@@ -81,11 +88,36 @@ def _convert_headers_to_cli_format(self, raw_media_download_headers: Optional[Di
             raise InternalError("media_download_headers should be a dictionary")
         return "\r\n".join([f"{key}: {value}" for key, value in raw_media_download_headers.items()])
 
-    def _fetch_file_metadata(self) -> Tuple[float, float]:
+    def _infer_modality_from_probe(self, modality_list: list[str], format_name: Optional[str]) -> Optional[Modality]:
+        """
+        Infer the modality from the probed media file. This is used to determine whether the media is audio or video.
+        """
+        if Modality.VIDEO in modality_list:
+            # Images are also considered as video in ffmpeg, so we need to check the format name to
+            # differentiate between video and image
+            if "image" in format_name or "_pipe" in format_name:
+                return Modality.IMAGE
+            else:
+                return Modality.VIDEO
+        elif Modality.AUDIO in modality_list:
+            return Modality.AUDIO
+        else:
+            return None
+
+    def _fetch_file_metadata(self) -> Tuple[float, float, Optional[Modality]]:
+        """
+        Fetch the metadata of the media file using ffmpeg. This includes the size, duration, and modality of the
+        media file.
+
+        Returns:
+            Tuple[float, float, str]: A tuple containing the size (in bytes), duration (in seconds), and modality of the
+            media file.
+
+        """
         try:
             probe_options = {
                 'v': 'error',
-                'show_entries': 'format=size,duration',
+                'show_entries': 'stream=codec_type,format=size,duration,format_name',
                 'of': 'json',
                 'probesize': '256K', # Probe only the first 256KB
             }
@@ -97,8 +129,11 @@ def _fetch_file_metadata(self) -> Tuple[float, float]:
 
             size = int(probe['format'].get('size', 0))
             duration = float(probe['format'].get('duration', 0))
+            format_name = probe['format'].get('format_name', "")
+            modality_list = [codec_type.get('codec_type', "") for codec_type in probe['streams']]
+            modality = self._infer_modality_from_probe(modality_list, format_name)
 
-            return size, duration
+            return size, duration, modality
 
         except ffmpeg.Error as e:
             raise MediaDownloadError(f"Error fetching metadata: {e.stderr.decode()}") from e
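For a standalone look at what this probe-based classification does, here is a sketch using the ffmpeg-python package directly; the URL is a placeholder, and the plain string return values stand in for the Modality enum used by the real method:

import ffmpeg  # pip install ffmpeg-python; requires the ffprobe binary on PATH

def probe_modality(url: str, probesize: str = "256K") -> str:
    """Classify a media URL as 'image', 'video', or 'audio' using the same heuristics as the method above."""
    probe = ffmpeg.probe(
        url,
        v="error",
        show_entries="stream=codec_type,format=size,duration,format_name",
        of="json",
        probesize=probesize,  # probe only the first 256KB by default
    )
    codec_types = [stream.get("codec_type", "") for stream in probe.get("streams", [])]
    format_name = probe.get("format", {}).get("format_name", "")

    if "video" in codec_types:
        # ffmpeg reports still images as a single video stream; the format name tells them apart
        return "image" if ("image" in format_name or "_pipe" in format_name) else "video"
    if "audio" in codec_types:
        return "audio"
    return "unknown"

# print(probe_modality("https://example.com/clip.mp4"))  # hypothetical URL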

tests/integ_tests/inference/native_inference/media_download_and_preprocess/test_streaming_media_preprcessor.py

Lines changed: 66 additions & 5 deletions
@@ -6,7 +6,7 @@
 import torch
 from pytest import mark
 
-from integ_tests.marqo_test import TestVideoUrls, TestAudioUrls
+from integ_tests.marqo_test import TestVideoUrls, TestAudioUrls, TestImageUrls
 from marqo.core.inference.api import *
 from marqo.inference.media_download_and_preprocess.streaming_media_processor import StreamingMediaProcessor
 from marqo.inference.native_inference.embedding_models.languagebind_model import LanguagebindPreprocessor
@@ -163,7 +163,7 @@ def test_metadata_fetching_success(self):
             url=valid_url, preprocessors=self.test_preprocessor,
             preprocessing_config=self.test_video_preprocessing_config
         )
-        size, duration = streaming_media_processor_object._fetch_file_metadata()
+        size, duration, _ = streaming_media_processor_object._fetch_file_metadata()
 
         self.assertEqual(2971504, size)  # Hardcoded value
         self.assertEqual(10.01, duration)  # Hardcoded value
@@ -215,11 +215,72 @@ def test_header_conversion_with_valid_headers(self):
         with patch("marqo.inference.media_download_and_preprocess"
                    ".streaming_media_processor.StreamingMediaProcessor._fetch_file_metadata") \
                 as mock_fetch_file_metadata:
-            mock_fetch_file_metadata.return_value = (2971504, 10.01)
+            mock_fetch_file_metadata.return_value = (2971504, 10.01, Modality.VIDEO)
             streaming_media_processor_object = StreamingMediaProcessor(
-                url=TestAudioUrls.AUDIO1.value, preprocessors=self.test_preprocessor,
+                url=TestVideoUrls.VIDEO1.value, preprocessors=self.test_preprocessor,
                 preprocessing_config=test_video_preprocessing_config
             )
 
             expected = "Authorization: Bearer token\r\nUser-Agent: Test"
-            self.assertEqual(streaming_media_processor_object.media_download_header, expected)
+            self.assertEqual(streaming_media_processor_object.media_download_header, expected)
+
+    def test_prob_modality_correct_video(self):
+        for url in [
+            TestVideoUrls.VIDEO1.value, TestVideoUrls.VIDEO2.value, TestVideoUrls.VIDEO3.value,
+            TestVideoUrls.MKV_VIDEO1.value, TestVideoUrls.WEBM_VIDEO1.value, TestVideoUrls.AVI_VIDEO1.value
+        ]:
+            with self.subTest(url=url):
+                streaming_media_processor_object = StreamingMediaProcessor(
+                    url=url, preprocessors=self.test_preprocessor,
+                    preprocessing_config=self.test_video_preprocessing_config
+                )
+                self.assertEqual(Modality.VIDEO, streaming_media_processor_object.probed_modality)
+
+    def test_prob_modality_correct_audio(self):
+        for url in [
+            TestAudioUrls.AUDIO1.value, TestAudioUrls.AUDIO2.value, TestAudioUrls.AUDIO3.value,
+            TestAudioUrls.MP3_AUDIO1.value, TestAudioUrls.MP3_AUDIO1.value, TestAudioUrls.ACC_AUDIO1.value,
+            TestAudioUrls.OGG_AUDIO1.value, TestAudioUrls.FLAC_AUDIO1.value
+        ]:
+            with self.subTest(url=url):
+                streaming_media_processor_object = StreamingMediaProcessor(
+                    url=url, preprocessors=self.test_preprocessor,
+                    preprocessing_config=self.test_audio_preprocessing_config
+                )
+                self.assertEqual(Modality.AUDIO, streaming_media_processor_object.probed_modality)
+
+    def test_prob_modality_correct_image(self):
+        """Ensure that the probed modality is correct for various image formats. Note that
+        an error is raised as StreamingMediaProcessor is not designed to handle images."""
+        for url in [
+            TestImageUrls.IMAGE1.value, TestImageUrls.IMAGE2.value, TestImageUrls.IMAGE3.value,
+            TestImageUrls.COCO.value
+        ]:
+            with self.subTest(url=url):
+                with self.assertRaises(MediaMismatchError) as e:
+                    _ = StreamingMediaProcessor(
+                        url=url, preprocessors=self.test_preprocessor,
+                        preprocessing_config=self.test_video_preprocessing_config
+                    )
+                self.assertIn('the detected modality image', str(e.exception))
+
+
+    def test_incorrect_modality_between_audio_and_video_will_raise_an_error(self):
+        test_cases = [
+            (TestVideoUrls.VIDEO1.value, self.test_audio_preprocessing_config,
+             "The url is video, but the preprocessing config is audio"),
+            (TestAudioUrls.AUDIO1.value, self.test_video_preprocessing_config,
+             "The url is audio, but the preprocessing config is video")
+        ]
+        for url, processing_config, msg in test_cases:
+            with self.subTest(msg):
+                with self.assertRaises(MediaMismatchError) as e:
+                    _ = StreamingMediaProcessor(
+                        url=url, preprocessors=self.test_preprocessor,
+                        preprocessing_config=processing_config
+                    )
+                self.assertIn("Please check your media file and try again", str(e.exception))
tests/integ_tests/tensor_search/integ_tests/test_add_documents_structured.py

Lines changed: 3 additions & 4 deletions
@@ -125,7 +125,7 @@ def setUpClass(cls) -> None:
                 )
             ],
             tensor_fields=['image_field', 'image_field_2'],
-            model=Model(name='ViT-B/16')
+            model=Model(name='open_clip/ViT-B-16/openai')
         )
         index_request_img_chunking = cls.structured_marqo_index_request(
             fields=[
@@ -141,7 +141,7 @@ def setUpClass(cls) -> None:
                 )
             ],
             tensor_fields=['image_field'],
-            model=Model(name='ViT-B/16'),
+            model=Model(name='open_clip/ViT-B-16/openai'),
            normalize_embeddings=True,
            image_preprocessing=ImagePreProcessing(patch_method=PatchMethod.Frcnn)
         )
@@ -938,5 +938,4 @@ def test_add_documents_nonImageContentForAnImageField(self):
         for item in r.items:
             self.assertEqual(400, item.status)
             # modality mismatch
-            self.assertIn("Error processing image_field, detected as language, "
-                          "but expected field type is image_pointer", item.message)
+            self.assertIn("is not a local file or a valid url", item.message)
