2626class SemiStructuredVespaIndex (StructuredVespaIndex , UnstructuredVespaIndex ):
2727 """
2828 An implementation of VespaIndex for SemiStructured indexes.
29+ TODO the multi-inheritance makes the implementation difficult to reason about. Consider refactor to composition
30+ instead. e.g. extract different logics to different query component builders, and combined the result.
2931 """
3032 index_supports_partial_updates : bool = False
3133
@@ -108,8 +110,24 @@ def _to_vespa_hybrid_query(self, marqo_query):
108110 # add sort by and relevance cutoff
109111 self ._add_relevance_cutoff_and_sort_by_params (marqo_query , query )
110112
113+ # add the collapse_field to query
114+ if marqo_query .collapse_field_name :
115+ query .update (self ._generate_collapse_query_params (marqo_query .collapse_field_name ))
116+
111117 return query
112118
119+ def _generate_collapse_query_params (self , collapse_field_name : str ):
120+ return {
121+ 'collapsefield' : collapse_field_name ,
122+ 'collapsesize' : 1 , # currently fixed to 1, will support multiple if needed in the future
123+
124+ # use a different rank profile to ensure diversity in the result returned to Vespa container
125+ 'marqo__ranking.lexical.lexical' : common .RANK_PROFILE_BM25 + '_diversity' ,
126+ 'marqo__ranking.tensor.tensor' : common .RANK_PROFILE_EMBEDDING_SIMILARITY + '_diversity' ,
127+ 'marqo__ranking.lexical.tensor' : common .RANK_PROFILE_HYBRID_BM25_THEN_EMBEDDING_SIMILARITY + '_diversity' ,
128+ 'marqo__ranking.tensor.lexical' : common .RANK_PROFILE_HYBRID_EMBEDDING_SIMILARITY_THEN_BM25 + '_diversity' ,
129+ }
130+
113131 def _add_relevance_cutoff_and_sort_by_params (self , marqo_query , query ):
114132 if marqo_query .relevance_cutoff :
115133 query ["marqo__hybrid.relevanceCutoff.method" ] = marqo_query .relevance_cutoff .method
@@ -136,6 +154,8 @@ def _add_relevance_cutoff_and_sort_by_params(self, marqo_query, query):
136154 for index , field in enumerate (marqo_query .sort_by .fields ):
137155 query ["query_features" ][f'marqo__sort_field_weights_{ index } ' ] = {field .field_name : 1 }
138156
157+ return query
158+
139159 def _generate_facet_queries (self , marqo_query ):
140160 facets_query_skeleton = '%s limit 0 | %s'
141161 QUERY_DELIMITER = "\n ---MARQO-YQL-QUERY-DELIMITER---\n "
@@ -171,7 +191,8 @@ def _generate_facet_queries(self, marqo_query):
171191 f'{ base_yql } { filter_term } ' , f"all(group({ self ._TOTAL_HITS_GROUP_CONST } ) each(output(count())))" ))
172192
173193 if marqo_query .facets is not None :
174- facets_term = self ._get_facets_term (marqo_query .facets )
194+ facets_term = self ._get_facets_term (marqo_query .facets ,
195+ collapse_field_name = marqo_query .collapse_field_name )
175196
176197 if facets_term is not None :
177198 facet_queries .append (facets_query_skeleton % (f'{ base_yql } { filter_term } ' , facets_term ))
@@ -190,15 +211,17 @@ def _generate_facet_queries(self, marqo_query):
190211 new_filter_term = f' AND { new_filter_term } '
191212 else :
192213 new_filter_term = ''
193- new_facets_term = self ._get_facets_term (marqo_query .facets , facet_parameters .exclude_terms )
214+ new_facets_term = self ._get_facets_term (marqo_query .facets , facet_parameters .exclude_terms ,
215+ collapse_field_name = marqo_query .collapse_field_name )
194216
195217 query_yql = f'{ base_yql } { new_filter_term } '
196218
197219 facet_queries .append (facets_query_skeleton % (query_yql , new_facets_term ))
198220
199221 return QUERY_DELIMITER .join (facet_queries )
200222
201- def _get_facets_term (self , facets_parameters : FacetsParameters , exclusion_terms : List [str ] = None ) -> str :
223+ def _get_facets_term (self , facets_parameters : FacetsParameters , exclusion_terms : List [str ] = None ,
224+ collapse_field_name : Optional [str ] = None ) -> str :
202225 """
203226 Build a facets grouping query string from the provided facets_parameters.
204227 """
@@ -260,7 +283,11 @@ def build_field_group(field_config, field_name, field_id, field_type_overwrite=N
260283 params = build_group_parameters (field_config )
261284
262285 # Build output expression
263- if field_config .type == "number" :
286+ if collapse_field_name :
287+ # we can only get the count collapsed to this field regardless of data type
288+ output = f"each(group({ collapse_field_name } ) output(count()))"
289+ elif field_config .type == "number" :
290+ # if we do not collapse, we can get the following stats with count for number type
264291 aggregations = ["sum" , "avg" , "min" , "max" ]
265292 field_type = FIELD_TYPES [field_type_overwrite ]
266293 funcs = [f'{ func } ({ field_type } {{"{ field_name } "}})' for func in aggregations ]
@@ -321,6 +348,10 @@ def generate_equality_filter_string(node: search_filter.EqualityTerm) -> str:
321348 if node .field == MARQO_DOC_ID :
322349 return f'({ VESPA_FIELD_ID } contains "{ node .value } ")'
323350
351+ if self .get_marqo_index ().is_collapse_field (node .field ):
352+ # collapse field is indexed as attribute, can be used directly in a filter term
353+ return f'({ node .field } contains "{ node .value } ")'
354+
324355 # Bool Filter
325356 if node .value .lower () in self ._FILTER_STRING_BOOL_VALUES :
326357 filter_value = int (True if node .value .lower () == "true" else False )
@@ -1034,11 +1065,15 @@ def _combine_number_stats(self, current_stats, stats):
10341065 if current_stats == {}:
10351066 return stats
10361067 aggregated_stats = {}
1037- aggregated_stats ["sum" ] = current_stats ["sum" ] + stats ["sum" ]
10381068 aggregated_stats ["count" ] = current_stats ["count" ] + stats ["count" ]
1039- aggregated_stats ["avg" ] = (current_stats ["avg" ] * current_stats ["count" ] + stats ["avg" ] * stats ["count" ]) / (current_stats ["count" ] + stats ["count" ])
1040- aggregated_stats ["min" ] = min (current_stats ["min" ], stats ["min" ])
1041- aggregated_stats ["max" ] = max (current_stats ["max" ], stats ["max" ])
1069+ if "sum" in current_stats and "sum" in stats :
1070+ aggregated_stats ["sum" ] = current_stats ["sum" ] + stats ["sum" ]
1071+ if "avg" in current_stats and "avg" in stats :
1072+ aggregated_stats ["avg" ] = (current_stats ["avg" ] * current_stats ["count" ] + stats ["avg" ] * stats ["count" ]) / (current_stats ["count" ] + stats ["count" ])
1073+ if "min" in current_stats and "min" in stats :
1074+ aggregated_stats ["min" ] = min (current_stats ["min" ], stats ["min" ])
1075+ if "max" in current_stats and "max" in stats :
1076+ aggregated_stats ["max" ] = max (current_stats ["max" ], stats ["max" ])
10421077 return aggregated_stats
10431078
10441079
0 commit comments