Skip to content

Commit 00f31bc

Browse files
committed
Add tests for Marqo Query
1 parent 113157f commit 00f31bc

File tree

2 files changed

+262
-0
lines changed

2 files changed

+262
-0
lines changed

.cursor/rules/run-tests.mdc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ alwaysApply: true
99
- If running integ or API tests, make sure Vespa is running vis docker ps. If not running, use
1010
python scripts/vespa_local/vespa_local.py full_start to run Vespa first.
1111
- To run API tests, first run Marqo API in one process by running src/marqo/tensor_search/api.py using PYTHONPATH=./src MARQO_ENABLE_BATCH_APIS=true MARQO_MODE=COMBINED MARQO_MODELS_TO_PRELOAD="[]". While the API is running, run API tests via pytest using PYTHONPATH=./tests/api_tests/v1/tests/api_tests . If Marqo API fails to run, stop. Terminate Marqo API when done.
12+
- Unit tests most follow the same package hierarchy as the code they test.
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
from unittest import TestCase
2+
from pydantic.v1 import ValidationError
3+
4+
from marqo.core.models.marqo_query import (
5+
MarqoTensorQuery, MarqoQuery, MarqoHybridQuery
6+
)
7+
from marqo.core.models.score_modifier import ScoreModifier, ScoreModifierType
8+
from marqo.core.search.search_filter import SearchFilter, EqualityTerm
9+
from marqo.core.models.hybrid_parameters import (
10+
HybridParameters, RankingMethod, RetrievalMethod
11+
)
12+
from marqo.core.models.facets_parameters import (
13+
FacetsParameters, FieldFacetsConfiguration
14+
)
15+
16+
17+
class TestMarqoTensorQuery(TestCase):
18+
19+
def test_creation_with_all_values(self):
20+
"""Test creating MarqoTensorQuery with all possible values."""
21+
score_modifier = ScoreModifier(
22+
field="test_field",
23+
weight=1.5,
24+
type=ScoreModifierType.Multiply
25+
)
26+
27+
filter_obj = SearchFilter(
28+
EqualityTerm("field1", "value1", "field1:value1")
29+
)
30+
31+
query = MarqoTensorQuery(
32+
index_name="test_index",
33+
limit=20,
34+
offset=5,
35+
searchable_attributes=["field1", "field2"],
36+
attributes_to_retrieve=["field1", "field3"],
37+
filter=filter_obj,
38+
score_modifiers=[score_modifier],
39+
expose_facets=True,
40+
vector_query=[0.1, 0.2, 0.3, 0.4],
41+
ef_search=100,
42+
approximate=False,
43+
approximate_threshold=0.95,
44+
rerank_depth_tensor=50
45+
)
46+
47+
# Verify all fields are set correctly
48+
self.assertEqual("test_index", query.index_name)
49+
self.assertEqual(20, query.limit)
50+
self.assertEqual(5, query.offset)
51+
self.assertEqual(["field1", "field2"], query.searchable_attributes)
52+
self.assertEqual(["field1", "field3"], query.attributes_to_retrieve)
53+
self.assertEqual(filter_obj, query.filter)
54+
self.assertEqual([score_modifier], query.score_modifiers)
55+
self.assertTrue(query.expose_facets)
56+
self.assertEqual([0.1, 0.2, 0.3, 0.4], query.vector_query)
57+
self.assertEqual(100, query.ef_search)
58+
self.assertFalse(query.approximate)
59+
self.assertEqual(0.95, query.approximate_threshold)
60+
self.assertEqual(50, query.rerank_depth_tensor)
61+
62+
# Test inheritance
63+
self.assertIsInstance(query, MarqoQuery)
64+
self.assertIsInstance(query, MarqoTensorQuery)
65+
66+
def test_required_fields(self):
67+
"""Test that all required fields must be provided."""
68+
base_params = {
69+
"index_name": "test_index",
70+
"limit": 10,
71+
"vector_query": [0.1, 0.2, 0.3]
72+
}
73+
74+
required_fields = ["index_name", "limit", "vector_query"]
75+
76+
for required_field in required_fields:
77+
with self.subTest(missing_field=required_field):
78+
params = base_params.copy()
79+
del params[required_field]
80+
81+
with self.assertRaises(ValidationError) as context:
82+
MarqoTensorQuery(**params)
83+
84+
self.assertIn(required_field, str(context.exception))
85+
86+
87+
class TestMarqoHybridQuery(TestCase):
88+
89+
def test_creation_with_all_values(self):
90+
"""Test creating MarqoHybridQuery with all possible values."""
91+
score_modifier = ScoreModifier(
92+
field="test_field",
93+
weight=1.5,
94+
type=ScoreModifierType.Multiply
95+
)
96+
97+
filter_obj = SearchFilter(
98+
EqualityTerm("field1", "value1", "field1:value1")
99+
)
100+
101+
hybrid_parameters = HybridParameters(
102+
retrievalMethod=RetrievalMethod.Disjunction,
103+
rankingMethod=RankingMethod.RRF,
104+
alpha=0.7,
105+
rrfK=100
106+
)
107+
108+
facets = FacetsParameters(
109+
fields={
110+
"test_field": FieldFacetsConfiguration(type="string")
111+
},
112+
maxDepth=5,
113+
maxResults=100
114+
)
115+
116+
query = MarqoHybridQuery(
117+
index_name="test_index",
118+
limit=20,
119+
offset=5,
120+
attributes_to_retrieve=["field1", "field3"],
121+
filter=filter_obj,
122+
expose_facets=True,
123+
vector_query=[0.1, 0.2, 0.3, 0.4],
124+
ef_search=100,
125+
approximate=False,
126+
approximate_threshold=0.95,
127+
rerank_depth_tensor=50,
128+
or_phrases=["phrase1", "phrase2"],
129+
and_phrases=["phrase3"],
130+
hybrid_parameters=hybrid_parameters,
131+
score_modifiers_lexical=[score_modifier],
132+
score_modifiers_tensor=[score_modifier],
133+
global_rerank_depth=100,
134+
facets=facets,
135+
track_total_hits=True
136+
)
137+
138+
# Verify all fields are set correctly
139+
self.assertEqual("test_index", query.index_name)
140+
self.assertEqual(20, query.limit)
141+
self.assertEqual(5, query.offset)
142+
self.assertEqual(["field1", "field3"], query.attributes_to_retrieve)
143+
self.assertEqual(filter_obj, query.filter)
144+
self.assertTrue(query.expose_facets)
145+
self.assertEqual([0.1, 0.2, 0.3, 0.4], query.vector_query)
146+
self.assertEqual(100, query.ef_search)
147+
self.assertFalse(query.approximate)
148+
self.assertEqual(0.95, query.approximate_threshold)
149+
self.assertEqual(50, query.rerank_depth_tensor)
150+
self.assertEqual(["phrase1", "phrase2"], query.or_phrases)
151+
self.assertEqual(["phrase3"], query.and_phrases)
152+
self.assertEqual(hybrid_parameters, query.hybrid_parameters)
153+
self.assertEqual([score_modifier], query.score_modifiers_lexical)
154+
self.assertEqual([score_modifier], query.score_modifiers_tensor)
155+
self.assertEqual(100, query.global_rerank_depth)
156+
self.assertEqual(facets, query.facets)
157+
self.assertTrue(query.track_total_hits)
158+
159+
# Test inheritance
160+
self.assertIsInstance(query, MarqoQuery)
161+
self.assertIsInstance(query, MarqoHybridQuery)
162+
163+
def test_required_fields(self):
164+
"""Test that all required fields must be provided."""
165+
hybrid_parameters = HybridParameters()
166+
167+
base_params = {
168+
"index_name": "test_index",
169+
"limit": 10,
170+
"or_phrases": ["phrase1"],
171+
"and_phrases": ["phrase2"],
172+
"hybrid_parameters": hybrid_parameters
173+
}
174+
175+
required_fields = [
176+
"index_name", "limit", "or_phrases", "and_phrases",
177+
"hybrid_parameters"
178+
]
179+
180+
for required_field in required_fields:
181+
with self.subTest(missing_field=required_field):
182+
params = base_params.copy()
183+
del params[required_field]
184+
185+
with self.assertRaises(ValidationError) as context:
186+
MarqoHybridQuery(**params)
187+
188+
self.assertIn(required_field, str(context.exception))
189+
190+
def test_score_modifiers_validation_with_rrf(self):
191+
"""Test that score_modifiers is allowed with RRF ranking method."""
192+
score_modifier = ScoreModifier(
193+
field="test_field",
194+
weight=1.5,
195+
type=ScoreModifierType.Multiply
196+
)
197+
198+
hybrid_parameters = HybridParameters(
199+
rankingMethod=RankingMethod.RRF
200+
)
201+
202+
# Should work with RRF
203+
query = MarqoHybridQuery(
204+
index_name="test_index",
205+
limit=10,
206+
or_phrases=["phrase1"],
207+
and_phrases=["phrase2"],
208+
hybrid_parameters=hybrid_parameters,
209+
score_modifiers=[score_modifier]
210+
)
211+
self.assertEqual([score_modifier], query.score_modifiers)
212+
213+
def test_score_modifiers_validation_with_non_rrf(self):
214+
"""Test that score_modifiers raises error with non-RRF ranking methods."""
215+
score_modifier = ScoreModifier(
216+
field="test_field",
217+
weight=1.5,
218+
type=ScoreModifierType.Multiply
219+
)
220+
221+
non_rrf_methods = [RankingMethod.Tensor, RankingMethod.Lexical]
222+
223+
for ranking_method in non_rrf_methods:
224+
with self.subTest(ranking_method=ranking_method):
225+
hybrid_parameters = HybridParameters(
226+
retrievalMethod=RetrievalMethod.Tensor,
227+
rankingMethod=ranking_method
228+
)
229+
230+
with self.assertRaises(ValidationError) as context:
231+
MarqoHybridQuery(
232+
index_name="test_index",
233+
limit=10,
234+
or_phrases=["phrase1"],
235+
and_phrases=["phrase2"],
236+
hybrid_parameters=hybrid_parameters,
237+
score_modifiers=[score_modifier]
238+
)
239+
240+
error_msg = ("'scoreModifiers' is only supported for hybrid "
241+
"search if 'rankingMethod' is 'RRF'")
242+
self.assertIn(error_msg, str(context.exception))
243+
244+
def test_searchable_attributes_validation_fails(self):
245+
"""Test that searchable_attributes cannot be used in hybrid search."""
246+
hybrid_parameters = HybridParameters()
247+
248+
with self.assertRaises(ValidationError) as context:
249+
MarqoHybridQuery(
250+
index_name="test_index",
251+
limit=10,
252+
or_phrases=["phrase1"],
253+
and_phrases=["phrase2"],
254+
hybrid_parameters=hybrid_parameters,
255+
searchable_attributes=["field1", "field2"]
256+
)
257+
258+
self.assertIn(
259+
"'searchableAttributes' cannot be used for hybrid search",
260+
str(context.exception)
261+
)

0 commit comments

Comments
 (0)