|
| 1 | +from unittest import TestCase |
| 2 | +from pydantic.v1 import ValidationError |
| 3 | + |
| 4 | +from marqo.core.models.marqo_query import ( |
| 5 | + MarqoTensorQuery, MarqoQuery, MarqoHybridQuery |
| 6 | +) |
| 7 | +from marqo.core.models.score_modifier import ScoreModifier, ScoreModifierType |
| 8 | +from marqo.core.search.search_filter import SearchFilter, EqualityTerm |
| 9 | +from marqo.core.models.hybrid_parameters import ( |
| 10 | + HybridParameters, RankingMethod, RetrievalMethod |
| 11 | +) |
| 12 | +from marqo.core.models.facets_parameters import ( |
| 13 | + FacetsParameters, FieldFacetsConfiguration |
| 14 | +) |
| 15 | + |
| 16 | + |
class TestMarqoTensorQuery(TestCase):
    """Construction and required-field checks for ``MarqoTensorQuery``."""

    def test_creation_with_all_values(self):
        """Build a query with every argument set and read each one back."""
        modifier = ScoreModifier(
            field="test_field",
            weight=1.5,
            type=ScoreModifierType.Multiply
        )
        search_filter = SearchFilter(
            EqualityTerm("field1", "value1", "field1:value1")
        )

        # Insertion order doubles as the order in which values are verified.
        given = {
            "index_name": "test_index",
            "limit": 20,
            "offset": 5,
            "searchable_attributes": ["field1", "field2"],
            "attributes_to_retrieve": ["field1", "field3"],
            "filter": search_filter,
            "score_modifiers": [modifier],
            "expose_facets": True,
            "vector_query": [0.1, 0.2, 0.3, 0.4],
            "ef_search": 100,
            "approximate": False,
            "approximate_threshold": 0.95,
            "rerank_depth_tensor": 50,
        }
        query = MarqoTensorQuery(**given)

        # Boolean flags are checked for truthiness, everything else for
        # equality against the value that was passed in.
        for attribute, value in given.items():
            stored = getattr(query, attribute)
            if value is True:
                self.assertTrue(stored)
            elif value is False:
                self.assertFalse(stored)
            else:
                self.assertEqual(value, stored)

        # The tensor query must also be usable wherever a base query is.
        for expected_type in (MarqoQuery, MarqoTensorQuery):
            self.assertIsInstance(query, expected_type)

    def test_required_fields(self):
        """Dropping any one mandatory argument raises a ValidationError naming it."""
        complete = {
            "index_name": "test_index",
            "limit": 10,
            "vector_query": [0.1, 0.2, 0.3]
        }

        for omitted in ("index_name", "limit", "vector_query"):
            with self.subTest(missing_field=omitted):
                partial = {
                    key: value
                    for key, value in complete.items()
                    if key != omitted
                }

                with self.assertRaises(ValidationError) as raised:
                    MarqoTensorQuery(**partial)

                self.assertIn(omitted, str(raised.exception))
| 85 | + |
| 86 | + |
class TestMarqoHybridQuery(TestCase):
    """Construction and validation checks for ``MarqoHybridQuery``."""

    @staticmethod
    def _multiply_modifier():
        """Return the multiply-type ScoreModifier reused across these tests."""
        return ScoreModifier(
            field="test_field",
            weight=1.5,
            type=ScoreModifierType.Multiply
        )

    @staticmethod
    def _minimal_hybrid_query(hybrid_parameters, **extra):
        """Build a hybrid query from the common arguments plus ``extra``."""
        return MarqoHybridQuery(
            index_name="test_index",
            limit=10,
            or_phrases=["phrase1"],
            and_phrases=["phrase2"],
            hybrid_parameters=hybrid_parameters,
            **extra
        )

    def test_creation_with_all_values(self):
        """Build a query with every argument set and read each one back."""
        modifier = self._multiply_modifier()
        search_filter = SearchFilter(
            EqualityTerm("field1", "value1", "field1:value1")
        )
        hybrid = HybridParameters(
            retrievalMethod=RetrievalMethod.Disjunction,
            rankingMethod=RankingMethod.RRF,
            alpha=0.7,
            rrfK=100
        )
        facet_settings = FacetsParameters(
            fields={
                "test_field": FieldFacetsConfiguration(type="string")
            },
            maxDepth=5,
            maxResults=100
        )

        # Insertion order doubles as the order in which values are verified.
        given = {
            "index_name": "test_index",
            "limit": 20,
            "offset": 5,
            "attributes_to_retrieve": ["field1", "field3"],
            "filter": search_filter,
            "expose_facets": True,
            "vector_query": [0.1, 0.2, 0.3, 0.4],
            "ef_search": 100,
            "approximate": False,
            "approximate_threshold": 0.95,
            "rerank_depth_tensor": 50,
            "or_phrases": ["phrase1", "phrase2"],
            "and_phrases": ["phrase3"],
            "hybrid_parameters": hybrid,
            "score_modifiers_lexical": [modifier],
            "score_modifiers_tensor": [modifier],
            "global_rerank_depth": 100,
            "facets": facet_settings,
            "track_total_hits": True,
        }
        query = MarqoHybridQuery(**given)

        # Boolean flags are checked for truthiness, everything else for
        # equality against the value that was passed in.
        for attribute, value in given.items():
            stored = getattr(query, attribute)
            if value is True:
                self.assertTrue(stored)
            elif value is False:
                self.assertFalse(stored)
            else:
                self.assertEqual(value, stored)

        # The hybrid query must also be usable wherever a base query is.
        for expected_type in (MarqoQuery, MarqoHybridQuery):
            self.assertIsInstance(query, expected_type)

    def test_required_fields(self):
        """Dropping any one mandatory argument raises a ValidationError naming it."""
        default_hybrid = HybridParameters()

        complete = {
            "index_name": "test_index",
            "limit": 10,
            "or_phrases": ["phrase1"],
            "and_phrases": ["phrase2"],
            "hybrid_parameters": default_hybrid
        }
        mandatory = (
            "index_name", "limit", "or_phrases", "and_phrases",
            "hybrid_parameters",
        )

        for omitted in mandatory:
            with self.subTest(missing_field=omitted):
                partial = {
                    key: value
                    for key, value in complete.items()
                    if key != omitted
                }

                with self.assertRaises(ValidationError) as raised:
                    MarqoHybridQuery(**partial)

                self.assertIn(omitted, str(raised.exception))

    def test_score_modifiers_validation_with_rrf(self):
        """score_modifiers is accepted when the ranking method is RRF."""
        modifier = self._multiply_modifier()
        rrf_parameters = HybridParameters(
            rankingMethod=RankingMethod.RRF
        )

        # Construction must succeed and keep the modifier list as given.
        query = self._minimal_hybrid_query(
            rrf_parameters, score_modifiers=[modifier]
        )

        self.assertEqual([modifier], query.score_modifiers)

    def test_score_modifiers_validation_with_non_rrf(self):
        """score_modifiers is rejected for the Tensor and Lexical ranking methods."""
        modifier = self._multiply_modifier()
        expected_message = (
            "'scoreModifiers' is only supported for hybrid "
            "search if 'rankingMethod' is 'RRF'"
        )

        for ranking_method in (RankingMethod.Tensor, RankingMethod.Lexical):
            with self.subTest(ranking_method=ranking_method):
                # Built outside assertRaises so only the query construction
                # is expected to raise.
                parameters = HybridParameters(
                    retrievalMethod=RetrievalMethod.Tensor,
                    rankingMethod=ranking_method
                )

                with self.assertRaises(ValidationError) as raised:
                    self._minimal_hybrid_query(
                        parameters, score_modifiers=[modifier]
                    )

                self.assertIn(expected_message, str(raised.exception))

    def test_searchable_attributes_validation_fails(self):
        """Passing searchable_attributes to a hybrid query raises a ValidationError."""
        default_hybrid = HybridParameters()

        with self.assertRaises(ValidationError) as raised:
            self._minimal_hybrid_query(
                default_hybrid, searchable_attributes=["field1", "field2"]
            )

        failure_text = str(raised.exception)
        self.assertIn(
            "'searchableAttributes' cannot be used for hybrid search",
            failure_text
        )
0 commit comments