Skip to content

Commit 6755207

Browse files
authored
Support collapse fields (PR 1/2) (#1276)
1 parent 9b3e5c2 commit 6755207

28 files changed

+1393
-66
lines changed

src/marqo/core/models/marqo_index.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,21 @@
1818
# TODO refactor to remove dep to s2_inference
1919
from marqo.s2_inference import s2_inference
2020
from marqo.s2_inference.errors import UnknownModelError, InvalidModelPropertiesError
21-
import marqo.core.constants as constants
2221

2322
logger = get_logger(__name__)
2423

2524

25+
class CollapseField(StrictBaseModel):
26+
name: str
27+
min_groups: int = pydantic.Field(default=500, gt=0, alias='minGroups')
28+
29+
@validator('name')
30+
def validate_field_name_collapse(cls, v):
31+
# Use common field name validation
32+
validate_field_name(v)
33+
return v
34+
35+
2636
class IndexType(Enum):
2737
Structured = 'structured'
2838
Unstructured = 'unstructured'
@@ -520,6 +530,7 @@ class SemiStructuredMarqoIndex(UnstructuredMarqoIndex):
520530
tensor_fields: List[TensorField]
521531
string_array_fields: Optional[List[
522532
StringArrayField]] # This is required so that when saving a document containing string array fields, we can make changes to the schema on the fly. Ref: https://github.com/marqo-ai/marqo/blob/cfea70adea7039d1586c94e36adae8e66cabe306/src/marqo/core/semi_structured_vespa_index/semi_structured_vespa_schema_template_2_16.sd.jinja2#L83
533+
collapse_fields: Optional[List[CollapseField]] = None
523534

524535
def __init__(self, **data):
525536
super().__init__(**data)
@@ -528,6 +539,18 @@ def __init__(self, **data):
528539
def _valid_type(cls) -> IndexType:
529540
return IndexType.SemiStructured
530541

542+
@root_validator
543+
def validate_collapse_fields(cls, values):
544+
collapse_fields = values.get('collapse_fields')
545+
if collapse_fields is not None and len(collapse_fields) != 1:
546+
raise ValueError("There must be exactly one collapse field")
547+
return values
548+
549+
def is_collapse_field(self, field_name: str) -> bool:
550+
if not self.collapse_fields:
551+
return False
552+
return field_name in [field.name for field in self.collapse_fields]
553+
531554
@property
532555
def field_map(self) -> Dict[str, Field]:
533556
"""

src/marqo/core/models/marqo_index_request.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ class UnstructuredMarqoIndexRequest(MarqoIndexRequest):
4141
treat_urls_and_pointers_as_images: bool
4242
treat_urls_and_pointers_as_media: bool
4343
filter_string_max_length: int
44+
collapse_fields: Optional[List[marqo_index.CollapseField]] = None
45+
46+
@root_validator
47+
def validate_collapse_fields(cls, values):
48+
collapse_fields = values.get('collapse_fields')
49+
if collapse_fields is not None:
50+
if len(collapse_fields) == 0:
51+
raise ValueError("collapse_fields cannot be an empty list")
52+
if len(collapse_fields) > 1:
53+
raise ValueError("Only one collapse field is supported")
54+
return values
4455

4556

4657
class FieldRequest(StrictBaseModel):

src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,43 @@ def __init__(self, marqo_index: SemiStructuredMarqoIndex, add_docs_params: AddDo
4343
self.should_update_index = False
4444
self.field_count_config = field_count_config
4545

46+
def _validate_doc(self, doc) -> None:
47+
"""Override parent validation to add collapse field validation."""
48+
# Call parent validation first
49+
super()._validate_doc(doc)
50+
51+
# Add collapse field validation
52+
self._validate_collapse_field_presence(doc)
53+
54+
def _validate_collapse_field_presence(self, doc: Dict[str, Any]) -> None:
55+
"""Validate that documents contain required collapse fields with correct type."""
56+
if not self.marqo_index.collapse_fields:
57+
return # No collapse fields configured
58+
59+
collapse_field = self.marqo_index.collapse_fields[0] # Only one allowed per spec
60+
collapse_field_name = collapse_field.name
61+
62+
# TODO confirm if all these validations are required
63+
if collapse_field_name not in doc:
64+
raise AddDocumentsError(
65+
f"Document missing required field '{collapse_field_name}'. "
66+
f"All documents must contain this field for grouping."
67+
)
68+
69+
collapse_value = doc[collapse_field_name]
70+
71+
if not isinstance(collapse_value, str):
72+
raise AddDocumentsError(
73+
f"Field '{collapse_field_name}' must be of type string. "
74+
f"Got {type(collapse_value).__name__}: {collapse_value}"
75+
)
76+
77+
if not collapse_value.strip():
78+
raise AddDocumentsError(
79+
f"Field '{collapse_field_name}' cannot be empty. "
80+
f"Provide a non-empty string value for grouping."
81+
)
82+
4683
def _handle_field(self, marqo_doc, field_name, field_content):
4784
"""Handle a field in a Marqo document by processing it and updating the index schema if needed.
4885
@@ -54,6 +91,10 @@ def _handle_field(self, marqo_doc, field_name, field_content):
5491
# Process field using parent class handler
5592
super()._handle_field(marqo_doc, field_name, field_content)
5693

94+
# Skip automatic lexical field creation for collapse fields - they are predefined in schema
95+
if self.marqo_index.is_collapse_field(field_name):
96+
return
97+
5798
# Add lexical field if content is a string
5899
if isinstance(marqo_doc[field_name], str):
59100
language = self._get_field_language(field_name)

src/marqo/core/semi_structured_vespa_index/semi_structured_document.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def from_vespa_document(cls, document: Dict, marqo_index: SemiStructuredMarqoInd
8585
string_array_prefix_length = len(common.STRING_ARRAY + '_')
8686

8787
for field_name, field_value in fields.items():
88-
if field_name in tensor_subfield_map:
88+
if marqo_index.is_collapse_field(field_name):
89+
text_fields[field_name] = field_value
90+
elif field_name in tensor_subfield_map:
8991
tensor_fields[field_name] = field_value
9092
elif field_name in lexical_field_map:
9193
# Lexical fields are returned with prefixed name from get_by_ids
@@ -199,6 +201,10 @@ def _handle_field_content(cls, field_name: str, field_content: Union[str, bool,
199201

200202
@classmethod
201203
def _handle_string_field(cls, field_name: str, field_content: str, instance, marqo_index: SemiStructuredMarqoIndex):
204+
if marqo_index.is_collapse_field(field_name):
205+
instance.text_fields[field_name] = field_content
206+
return
207+
202208
if field_name not in marqo_index.field_map:
203209
raise MarqoDocumentParsingError(f'Field {field_name} is not in index {marqo_index.name}')
204210

@@ -209,6 +215,7 @@ def _handle_string_field(cls, field_name: str, field_content: str, instance, mar
209215
instance.fixed_fields.short_string_fields[field_name] = field_content
210216

211217
if instance.index_supports_partial_updates:
218+
# TODO do we need to store field type of collapse field? maybe since we need to support partial updates
212219
instance.fixed_fields.field_types[field_name] = MarqoFieldTypes.STRING.value
213220

214221
@classmethod

src/marqo/core/semi_structured_vespa_index/semi_structured_vespa_schema.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,15 @@ def generate_vespa_schema(cls, marqo_index: SemiStructuredMarqoIndex) -> str:
3131
vespa_schema_template = environment.get_template("semi_structured_vespa_schema_template_2_16.sd.jinja2")
3232
else:
3333
vespa_schema_template = environment.get_template("semi_structured_vespa_schema_template.sd.jinja2")
34-
return vespa_schema_template.render(index=marqo_index, dimension=str(marqo_index.model.get_dimension()))
34+
35+
# simplify the logic in the template to just pass in the first collapse field if exists
36+
collapse_field = marqo_index.collapse_fields[0] if marqo_index.collapse_fields else None
37+
38+
return vespa_schema_template.render(
39+
index=marqo_index,
40+
collapse_field=collapse_field,
41+
dimension=str(marqo_index.model.get_dimension())
42+
)
3543

3644
def _generate_marqo_index(self, schema_name: str) -> SemiStructuredMarqoIndex:
3745
marqo_index = SemiStructuredMarqoIndex(
@@ -55,6 +63,7 @@ def _generate_marqo_index(self, schema_name: str) -> SemiStructuredMarqoIndex:
5563
filter_string_max_length=self._index_request.filter_string_max_length,
5664
treat_urls_and_pointers_as_images=self._index_request.treat_urls_and_pointers_as_images,
5765
treat_urls_and_pointers_as_media=self._index_request.treat_urls_and_pointers_as_media,
66+
collapse_fields=self._index_request.collapse_fields,
5867
)
5968

6069
return marqo_index

src/marqo/core/semi_structured_vespa_index/semi_structured_vespa_schema_template_2_16.sd.jinja2

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ schema {{ index.schema_name }} {
7171
indexing: summary
7272
}
7373

74+
{% if collapse_field -%}
75+
field {{ collapse_field.name }} type string {
76+
indexing: attribute | summary
77+
attribute: fast-search
78+
rank: filter
79+
}
80+
{% endif -%}
81+
7482
{% for lexical_field in index.lexical_fields -%}
7583
field {{ lexical_field.lexical_field_name }} type string {
7684
{%- if lexical_field.language %}
@@ -227,7 +235,19 @@ schema {{ index.schema_name }} {
227235
expression: sort_field_value(query(marqo__sort_field_weights_2))
228236
}
229237

230-
match-features: global_mult_modifier global_add_modifier sort_field_value_0 sort_field_value_1 sort_field_value_2
238+
{% if collapse_field -%}
239+
function collapse_field_hash() {
240+
expression: attribute({{ collapse_field.name }})
241+
}
242+
{% endif -%}
243+
match-features {
244+
global_mult_modifier
245+
global_add_modifier
246+
sort_field_value_0
247+
sort_field_value_1
248+
sort_field_value_2
249+
{% if collapse_field %}collapse_field_hash{% endif %}
250+
}
231251
}
232252

233253
{% if index.lexical_fields -%}
@@ -236,6 +256,17 @@ schema {{ index.schema_name }} {
236256
expression: modify(lexical_score(), query(marqo__mult_weights_lexical), query(marqo__add_weights_lexical))
237257
}
238258
}
259+
{% if collapse_field -%}
260+
rank-profile bm25_diversity inherits bm25 {
261+
diversity {
262+
attribute: {{ collapse_field.name }}
263+
min-groups: {{ collapse_field.min_groups }}
264+
}
265+
second-phase {
266+
expression: firstPhase
267+
}
268+
}
269+
{% endif -%}
239270
{% endif -%}
240271

241272
{# We provide this rank profile even without the tensor field to support embed requests -#}
@@ -255,6 +286,18 @@ schema {{ index.schema_name }} {
255286
{%- endif %}
256287
}
257288

289+
{% if collapse_field -%}
290+
rank-profile embedding_similarity_diversity inherits embedding_similarity {
291+
diversity {
292+
attribute: {{ collapse_field.name }}
293+
min-groups: {{ collapse_field.min_groups }}
294+
}
295+
second-phase {
296+
expression: firstPhase
297+
}
298+
}
299+
{% endif -%}
300+
258301
{% if index.lexical_fields and index.tensor_fields -%}
259302
rank-profile hybrid_custom_searcher inherits default {
260303
inputs {
@@ -295,6 +338,26 @@ schema {{ index.schema_name }} {
295338
expression: modify(lexical_score(), query(marqo__mult_weights_lexical), query(marqo__add_weights_lexical))
296339
}
297340
}
341+
342+
{% if collapse_field -%}
343+
rank-profile hybrid_bm25_then_embedding_similarity_diversity inherits hybrid_bm25_then_embedding_similarity {
344+
diversity {
345+
attribute: {{ collapse_field.name }}
346+
min-groups: {{ collapse_field.min_groups }}
347+
}
348+
}
349+
350+
rank-profile hybrid_embedding_similarity_then_bm25_diversity inherits hybrid_embedding_similarity_then_bm25 {
351+
diversity {
352+
attribute: {{ collapse_field.name }}
353+
min-groups: {{ collapse_field.min_groups }}
354+
}
355+
second-phase {
356+
expression: firstPhase
357+
}
358+
}
359+
{% endif -%}
360+
298361
{%- endif %}
299362

300363
document-summary all-non-vector-summary {
@@ -303,6 +366,9 @@ schema {{ index.schema_name }} {
303366
summary marqo__bool_fields type map<string, byte> {}
304367
summary marqo__int_fields type map<string, long> {}
305368
summary marqo__float_fields type map<string, double> {}
369+
{% if collapse_field -%}
370+
summary {{ collapse_field.name }} type string {}
371+
{% endif -%}
306372
{% for string_array_field in index.string_array_fields -%}
307373
summary {{ string_array_field.string_array_field_name }} type array<string> {source: {{ string_array_field.string_array_field_name }}}
308374
{% endfor -%}
@@ -320,6 +386,9 @@ schema {{ index.schema_name }} {
320386
summary marqo__bool_fields type map<string, byte> {}
321387
summary marqo__int_fields type map<string, long> {}
322388
summary marqo__float_fields type map<string, double> {}
389+
{% if collapse_field -%}
390+
summary {{ collapse_field.name }} type string {}
391+
{% endif -%}
323392
{% for string_array_field in index.string_array_fields -%}
324393
summary {{ string_array_field.string_array_field_name }} type array<string> {source: {{ string_array_field.string_array_field_name }}}
325394
{% endfor -%}

src/marqo/core/vespa_index/add_documents_handler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,17 @@ def add_documents(self) -> MarqoAddDocumentsResponse:
175175
with RequestMetricsStore.for_request().time("add_documents.vespa.to_vespa_docs"):
176176
vespa_docs = self._convert_to_vespa_docs()
177177

178-
self._pre_persist_to_vespa()
178+
if vespa_docs: # only continue if there's still vespa docs to persist
179+
self._pre_persist_to_vespa()
179180

180-
# persist to vespa if there are still valid docs
181-
with RequestMetricsStore.for_request().time("add_documents.vespa._bulk"):
182-
response = self.vespa_client.feed_batch(vespa_docs, self.marqo_index.schema_name)
181+
# persist to vespa if there are still valid docs
182+
with RequestMetricsStore.for_request().time("add_documents.vespa._bulk"):
183+
response = self.vespa_client.feed_batch(vespa_docs, self.marqo_index.schema_name)
183184

184-
with RequestMetricsStore.for_request().time("add_documents.postprocess"):
185-
self._handle_vespa_response(response)
185+
with RequestMetricsStore.for_request().time("add_documents.postprocess"):
186+
self._handle_vespa_response(response)
187+
else:
188+
logger.debug('Skipping the Vespa roundtrip since there is no valid doc to feed')
186189

187190
return self.add_docs_response_collector.to_add_doc_responses(self.marqo_index.name)
188191

src/marqo/tensor_search/models/index_settings.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class IndexSettings(StrictBaseModel):
2323
treatUrlsAndPointersAsImages: Optional[bool]
2424
treatUrlsAndPointersAsMedia: Optional[bool]
2525
filterStringMaxLength: Optional[int]
26+
collapseFields: Optional[List[core.CollapseField]] = None
2627
model: str = 'hf/e5-base-v2'
2728
modelProperties: Optional[Dict[str, Any]]
2829
textQueryPrefix: Optional[str] = None
@@ -74,6 +75,19 @@ def validate_url_pointer_treatment(cls, values):
7475

7576
return values
7677

78+
@root_validator
79+
def validate_collapse_fields(cls, values):
80+
collapse_fields = values.get('collapseFields')
81+
index_type = values.get('type')
82+
83+
# collapseFields is only supported for SemiStructuredIndex
84+
if collapse_fields is not None and index_type == core.IndexType.Structured:
85+
raise api_exceptions.InvalidArgError(
86+
"collapseFields is only supported for unstructured indexes"
87+
)
88+
89+
return values
90+
7791
@root_validator(pre=True)
7892
def validate_field_names(cls, values):
7993
# Verify no snake case field names (pydantic won't catch these due to allow_population_by_field_name = True)
@@ -191,6 +205,7 @@ def to_marqo_index_request(self, index_name: str) -> MarqoIndexRequest:
191205
treat_urls_and_pointers_as_images=self.treatUrlsAndPointersAsImages,
192206
treat_urls_and_pointers_as_media=self.treatUrlsAndPointersAsMedia,
193207
filter_string_max_length=self.filterStringMaxLength,
208+
collapse_fields=self.collapseFields,
194209
marqo_version=version.get_version(),
195210
created_at=time.time(),
196211
updated_at=time.time()
@@ -204,11 +219,18 @@ def from_marqo_index(cls, marqo_index: core.MarqoIndex) -> "IndexSettings":
204219
# This covers both UnstructuredMarqoIndex and SemiStructuredMarqoIndex
205220
# We intentionally hide the lexical and tensor fields info in SemiStructuredMarqoIndex from customers since
206221
# this information and the SemiStructured concept are internal implementation details only.
222+
223+
# Only include collapseFields for SemiStructuredMarqoIndex
224+
collapse_fields = None
225+
if isinstance(marqo_index, core.SemiStructuredMarqoIndex):
226+
collapse_fields = marqo_index.collapse_fields
227+
207228
return cls(
208229
type=core.IndexType.Unstructured,
209230
treatUrlsAndPointersAsImages=marqo_index.treat_urls_and_pointers_as_images,
210231
treatUrlsAndPointersAsMedia=marqo_index.treat_urls_and_pointers_as_media,
211232
filterStringMaxLength=marqo_index.filter_string_max_length,
233+
collapseFields=collapse_fields,
212234
model=marqo_index.model.name,
213235
modelProperties=IndexSettings.get_model_properties(marqo_index),
214236
normalizeEmbeddings=marqo_index.normalize_embeddings,

0 commit comments

Comments
 (0)