Skip to content

Commit 0209469

Browse files
authored
Support collapse fields (PR 2/2) (#1277)
1 parent 2f7417b commit 0209469

File tree

21 files changed

+2237
-146
lines changed

21 files changed

+2237
-146
lines changed

src/marqo/core/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
MARQO_LANGUAGE_MINIMUM_VERSION = semver.VersionInfo.parse('2.16.0')
2929
MARQO_STEMMING_MINIMUM_VERSION = semver.VersionInfo.parse('2.16.0')
3030
MARQO_PARTIAL_UPDATE_MINIMUM_VERSION = semver.VersionInfo.parse('2.16.0')
31+
MARQO_COLLAPSE_FIELDS_MINIMUM_VERSION = semver.VersionInfo.parse('2.23.0')
3132

3233
# For score modifiers
3334
QUERY_INPUT_SCORE_MODIFIERS_MULT_WEIGHTS_2_9 = 'marqo__mult_weights'

src/marqo/core/models/marqo_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class MarqoHybridQuery(MarqoTensorQuery, MarqoLexicalQuery):
7373
track_total_hits: Optional[bool] = None
7474
relevance_cutoff: Optional[RelevanceCutoffModel] = None
7575
sort_by: Optional[SortByModel] = None
76+
collapse_field_name: Optional[str] = None
7677

7778
@root_validator(pre=True)
7879
def validate_searchable_attributes_and_score_modifiers(cls, values):

src/marqo/core/search/hybrid_search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def search(
5050
language: Optional[str] = None,
5151
relevance_cutoff: Optional[RelevanceCutoffModel] = None,
5252
sort_by: Optional[SortByModel] = None,
53-
interpolation_method: Optional[InterpolationMethod] = None
53+
interpolation_method: Optional[InterpolationMethod] = None,
54+
collapse_field_name: Optional[str] = None
5455
) -> Dict:
5556
"""
5657
@@ -81,6 +82,7 @@ def search(
8182
relevance_cutoff: RelevanceCutoffModel object to specify relevance cutoff for the search.
8283
sort_by: SortByModel object to specify sorting for the search. If not provided, no sorting will be applied.
8384
interpolation_method: InterpolationMethod object to specify the interpolation method for hybrid search.
85+
collapse_field_name: field name to collapse the search result on.
8486
Returns:
8587
8688
Output format:
@@ -292,7 +294,8 @@ def search(
292294
track_total_hits=track_total_hits,
293295
language=language,
294296
relevance_cutoff=relevance_cutoff,
295-
sort_by=sort_by
297+
sort_by=sort_by,
298+
collapse_field_name=collapse_field_name
296299
)
297300

298301
vespa_index = vespa_index_factory(marqo_index)

src/marqo/core/semi_structured_vespa_index/semi_structured_vespa_index.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
class SemiStructuredVespaIndex(StructuredVespaIndex, UnstructuredVespaIndex):
2727
"""
2828
An implementation of VespaIndex for SemiStructured indexes.
29+
TODO the multi-inheritance makes the implementation difficult to reason about. Consider refactor to composition
30+
instead. e.g. extract different logics to different query component builders, and combined the result.
2931
"""
3032
index_supports_partial_updates: bool = False
3133

@@ -108,8 +110,24 @@ def _to_vespa_hybrid_query(self, marqo_query):
108110
# add sort by and relevance cutoff
109111
self._add_relevance_cutoff_and_sort_by_params(marqo_query, query)
110112

113+
# add the collapse_field to query
114+
if marqo_query.collapse_field_name:
115+
query.update(self._generate_collapse_query_params(marqo_query.collapse_field_name))
116+
111117
return query
112118

119+
def _generate_collapse_query_params(self, collapse_field_name: str):
120+
return {
121+
'collapsefield': collapse_field_name,
122+
'collapsesize': 1, # currently fixed to 1, will support multiple if needed in the future
123+
124+
# use a different rank profile to ensure diversity in the result returned to Vespa container
125+
'marqo__ranking.lexical.lexical': common.RANK_PROFILE_BM25 + '_diversity',
126+
'marqo__ranking.tensor.tensor': common.RANK_PROFILE_EMBEDDING_SIMILARITY + '_diversity',
127+
'marqo__ranking.lexical.tensor': common.RANK_PROFILE_HYBRID_BM25_THEN_EMBEDDING_SIMILARITY + '_diversity',
128+
'marqo__ranking.tensor.lexical': common.RANK_PROFILE_HYBRID_EMBEDDING_SIMILARITY_THEN_BM25 + '_diversity',
129+
}
130+
113131
def _add_relevance_cutoff_and_sort_by_params(self, marqo_query, query):
114132
if marqo_query.relevance_cutoff:
115133
query["marqo__hybrid.relevanceCutoff.method"] = marqo_query.relevance_cutoff.method
@@ -136,6 +154,8 @@ def _add_relevance_cutoff_and_sort_by_params(self, marqo_query, query):
136154
for index, field in enumerate(marqo_query.sort_by.fields):
137155
query["query_features"][f'marqo__sort_field_weights_{index}'] = {field.field_name: 1}
138156

157+
return query
158+
139159
def _generate_facet_queries(self, marqo_query):
140160
facets_query_skeleton = '%s limit 0 | %s'
141161
QUERY_DELIMITER = "\n---MARQO-YQL-QUERY-DELIMITER---\n"
@@ -171,7 +191,8 @@ def _generate_facet_queries(self, marqo_query):
171191
f'{base_yql}{filter_term}', f"all(group({self._TOTAL_HITS_GROUP_CONST}) each(output(count())))"))
172192

173193
if marqo_query.facets is not None:
174-
facets_term = self._get_facets_term(marqo_query.facets)
194+
facets_term = self._get_facets_term(marqo_query.facets,
195+
collapse_field_name=marqo_query.collapse_field_name)
175196

176197
if facets_term is not None:
177198
facet_queries.append(facets_query_skeleton % (f'{base_yql}{filter_term}', facets_term))
@@ -190,15 +211,17 @@ def _generate_facet_queries(self, marqo_query):
190211
new_filter_term = f' AND {new_filter_term}'
191212
else:
192213
new_filter_term = ''
193-
new_facets_term = self._get_facets_term(marqo_query.facets, facet_parameters.exclude_terms)
214+
new_facets_term = self._get_facets_term(marqo_query.facets, facet_parameters.exclude_terms,
215+
collapse_field_name=marqo_query.collapse_field_name)
194216

195217
query_yql = f'{base_yql}{new_filter_term}'
196218

197219
facet_queries.append(facets_query_skeleton % (query_yql, new_facets_term))
198220

199221
return QUERY_DELIMITER.join(facet_queries)
200222

201-
def _get_facets_term(self, facets_parameters: FacetsParameters, exclusion_terms: List[str] = None) -> str:
223+
def _get_facets_term(self, facets_parameters: FacetsParameters, exclusion_terms: List[str] = None,
224+
collapse_field_name: Optional[str] = None) -> str:
202225
"""
203226
Build a facets grouping query string from the provided facets_parameters.
204227
"""
@@ -260,7 +283,11 @@ def build_field_group(field_config, field_name, field_id, field_type_overwrite=N
260283
params = build_group_parameters(field_config)
261284

262285
# Build output expression
263-
if field_config.type == "number":
286+
if collapse_field_name:
287+
# we can only get the count collapsed to this field regardless of data type
288+
output = f"each(group({collapse_field_name}) output(count()))"
289+
elif field_config.type == "number":
290+
# if we do not collapse, we can get the following stats with count for number type
264291
aggregations = ["sum", "avg", "min", "max"]
265292
field_type = FIELD_TYPES[field_type_overwrite]
266293
funcs = [f'{func}({field_type}{{"{field_name}"}})' for func in aggregations]
@@ -321,6 +348,10 @@ def generate_equality_filter_string(node: search_filter.EqualityTerm) -> str:
321348
if node.field == MARQO_DOC_ID:
322349
return f'({VESPA_FIELD_ID} contains "{node.value}")'
323350

351+
if self.get_marqo_index().is_collapse_field(node.field):
352+
# collapse field is indexed as attribute, can be used directly in a filter term
353+
return f'({node.field} contains "{node.value}")'
354+
324355
# Bool Filter
325356
if node.value.lower() in self._FILTER_STRING_BOOL_VALUES:
326357
filter_value = int(True if node.value.lower() == "true" else False)
@@ -1034,11 +1065,15 @@ def _combine_number_stats(self, current_stats, stats):
10341065
if current_stats == {}:
10351066
return stats
10361067
aggregated_stats = {}
1037-
aggregated_stats["sum"] = current_stats["sum"] + stats["sum"]
10381068
aggregated_stats["count"] = current_stats["count"] + stats["count"]
1039-
aggregated_stats["avg"] = (current_stats["avg"] * current_stats["count"] + stats["avg"] * stats["count"]) / (current_stats["count"] + stats["count"])
1040-
aggregated_stats["min"] = min(current_stats["min"], stats["min"])
1041-
aggregated_stats["max"] = max(current_stats["max"], stats["max"])
1069+
if "sum" in current_stats and "sum" in stats:
1070+
aggregated_stats["sum"] = current_stats["sum"] + stats["sum"]
1071+
if "avg" in current_stats and "avg" in stats:
1072+
aggregated_stats["avg"] = (current_stats["avg"] * current_stats["count"] + stats["avg"] * stats["count"]) / (current_stats["count"] + stats["count"])
1073+
if "min" in current_stats and "min" in stats:
1074+
aggregated_stats["min"] = min(current_stats["min"], stats["min"])
1075+
if "max" in current_stats and "max" in stats:
1076+
aggregated_stats["max"] = max(current_stats["max"], stats["max"])
10421077
return aggregated_stats
10431078

10441079

src/marqo/tensor_search/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ def search(index_name: str, search_query_dict: dict, device: str = Depends(api_v
439439
language=search_query.language,
440440
relevance_cutoff= search_query.relevance_cutoff,
441441
sort_by = search_query.sort_by,
442-
interpolation_method=search_query.interpolationMethod
442+
interpolation_method=search_query.interpolationMethod,
443+
collapse_field_name=search_query.collapse_fields[0].name if search_query.collapse_fields else None
443444
)
444445
return ORJSONResponse(result)
445446
except Exception as e:

src/marqo/tensor_search/models/api_models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ class Config:
2929
pass
3030

3131

32+
class SearchCollapseField(ImmutableStrictBaseModel):
33+
name: str
34+
35+
3236
class CustomVectorQuery(ImmutableStrictBaseModel):
3337
class CustomVector(ImmutableStrictBaseModel):
3438
content: Optional[str] = None
@@ -68,6 +72,7 @@ class Config(BaseMarqoModel.Config):
6872
sort_by: Optional[SortByModel] = Field(default=None, alias="sortBy")
6973
relevance_cutoff: Optional[RelevanceCutoffModel] = Field(default=None, alias="relevanceCutoff")
7074
interpolationMethod: Optional[InterpolationMethod] = None
75+
collapse_fields: Optional[List[SearchCollapseField]] = Field(default=None, alias="collapseFields")
7176

7277
# By default, we retrieve 3 times more candidates than the limit to ensure we have enough results to sort.
7378
_DEFAULT_SORT_CANDIDATES_MULTIPLIER = 3
@@ -424,6 +429,25 @@ def _validate_and_set_sort_by_min_sort_candidates_parameters(cls, values):
424429
)
425430
return values
426431

432+
@root_validator(pre=False)
433+
def validate_collapse_fields_only_for_hybrid_search(cls, values):
434+
"""Validate collapse fields only provided for hybrid search"""
435+
collapse_fields = values.get('collapse_fields')
436+
search_method = values.get('searchMethod')
437+
if collapse_fields is not None and search_method.upper() != SearchMethod.HYBRID:
438+
raise ValueError(f"collapseFields can only be provided for 'HYBRID' search. "
439+
f"Search method is {search_method}.")
440+
return values
441+
442+
@root_validator(pre=False)
443+
def validate_single_collapse_field(cls, values):
444+
"""Validate exactly one collapse field is provided"""
445+
collapse_fields = values.get('collapse_fields')
446+
if collapse_fields is not None:
447+
if len(collapse_fields) != 1:
448+
raise ValueError("Exactly one collapse field must be provided")
449+
return values
450+
427451

428452
class BulkSearchQueryEntity(SearchQuery):
429453
index: MarqoIndex

src/marqo/tensor_search/tensor_search.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
359359
language: Optional[str] = None,
360360
relevance_cutoff: Optional[RelevanceCutoffModel] = None,
361361
sort_by: Optional[SortByModel] = None,
362-
interpolation_method: Optional[InterpolationMethod] = None
362+
interpolation_method: Optional[InterpolationMethod] = None,
363+
collapse_field_name: Optional[str] = None
363364
) -> Dict:
364365
"""The root search method. Calls the specific search method
365366
@@ -454,13 +455,32 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
454455
# Fetch marqo index to pass to search method
455456
marqo_index = index_meta_cache.get_index(index_management=config.index_management, index_name=index_name)
456457
marqo_index_version = marqo_index.parsed_marqo_version()
458+
459+
# Validate collapse field configuration
460+
if collapse_field_name is not None:
461+
# Validate if the index version support this feature
462+
if (marqo_index_version < constants.MARQO_COLLAPSE_FIELDS_MINIMUM_VERSION or
463+
not isinstance(marqo_index, SemiStructuredMarqoIndex)):
464+
index_type = 'structured' if marqo_index.type == IndexType.Structured else 'unstructured'
465+
raise core_exceptions.UnsupportedFeatureError(
466+
f"The 'collapseFields' search parameter is only supported for unstructured indexes created with "
467+
f"Marqo version {str(constants.MARQO_COLLAPSE_FIELDS_MINIMUM_VERSION)} or later. "
468+
f"This index is {index_type} and was created with Marqo {marqo_index_version}."
469+
)
470+
471+
# Validate collapse field exists in index configuration
472+
if not marqo_index.is_collapse_field(collapse_field_name):
473+
raise api_exceptions.InvalidArgError(f"Field '{collapse_field_name}' is not configured as a collapse field "
474+
f"for this index")
475+
457476
if rerank_depth is not None \
458477
and marqo_index_version < constants.MARQO_RERANK_DEPTH_MINIMUM_VERSION:
459478
raise core_exceptions.UnsupportedFeatureError(
460479
f"The 'rerankDepth' search parameter is only supported for indexes created with Marqo version "
461480
f"{str(constants.MARQO_RERANK_DEPTH_MINIMUM_VERSION)} or later. "
462481
f"This index was created with Marqo {marqo_index_version}."
463482
)
483+
464484
if search_method.upper() in {SearchMethod.TENSOR, SearchMethod.HYBRID}:
465485
# Default approximate and efSearch -- we can't set these at API-level since they're not a valid args
466486
# for lexical search
@@ -511,7 +531,8 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
511531
hybrid_parameters=hybrid_parameters, facets=facets, track_total_hits=track_total_hits,
512532
language=language,
513533
relevance_cutoff=relevance_cutoff, sort_by=sort_by,
514-
interpolation_method=interpolation_method
534+
interpolation_method=interpolation_method,
535+
collapse_field_name=collapse_field_name
515536
)
516537

517538
elif search_method.upper() == SearchMethod.LEXICAL:

src/marqo/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.22.1"
1+
__version__ = "2.23.0"
22

33
def get_version() -> str:
44
return f"{__version__}"

src/marqo/vespa/vespa_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def query(self, yql: str, hits: int = 10, ranking: str = None, model_restrict: s
257257

258258
self._query_raise_for_status(resp)
259259

260-
return QueryResult(**orjson.loads(resp.text))
260+
resp_dict = orjson.loads(resp.text)
261+
return QueryResult(**resp_dict)
261262

262263
def feed_document(self, document: VespaDocument, schema: str, timeout: int = 60) -> FeedDocumentResponse:
263264
"""

0 commit comments

Comments
 (0)