Skip to content

Commit 1238cf0

Browse files
committed
Add more unit tests
1 parent 86bd7c1 commit 1238cf0

File tree

3 files changed

+365
-1
lines changed

3 files changed

+365
-1
lines changed

.cursor/rules/run-tests.mdc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ alwaysApply: true
99
- If running integ or API tests, make sure Vespa is running via docker ps. If not running, use
1010
python scripts/vespa_local/vespa_local.py full_start to run Vespa first.
1111
- To run API tests, first run Marqo API in one process by running src/marqo/tensor_search/api.py using PYTHONPATH=./src MARQO_ENABLE_BATCH_APIS=true MARQO_MODE=COMBINED MARQO_MODELS_TO_PRELOAD="[]". While the API is running, run API tests via pytest using PYTHONPATH=./tests/api_tests/v1/tests/api_tests . If Marqo API fails to run, stop. Terminate Marqo API when done.
12-
- Unit tests most follow the same package hierarchy as the code they test.
12+
- Unit tests must follow the same package hierarchy as the code they test.
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import time
2+
import unittest
3+
from typing import List, Optional
4+
5+
from marqo.core.models.marqo_query import MarqoTensorQuery, MarqoHybridQuery
6+
from marqo.core.models.marqo_index import (
7+
StructuredMarqoIndex, Model, TextPreProcessing, TextSplitMethod,
8+
ImagePreProcessing, HnswConfig, DistanceMetric, Field, FieldType,
9+
FieldFeature, TensorField
10+
)
11+
from marqo.core.models.hybrid_parameters import (
12+
HybridParameters, RankingMethod, RetrievalMethod
13+
)
14+
from marqo.core.structured_vespa_index.structured_vespa_index import StructuredVespaIndex
15+
16+
17+
class TestStructuredVespaIndexToVespaQuery(unittest.TestCase):
    """Tests for how ``StructuredVespaIndex.to_vespa_query`` handles ``approximate_threshold``.

    Each test builds a Marqo query with ``approximate=True`` over a range of
    threshold values (including ``None``) and inspects the returned query
    mapping for the ``ranking.matching.approximateThreshold`` key.
    """

    # Key the tests look up in the mapping returned by ``to_vespa_query``.
    APPROXIMATE_THRESHOLD_KEY = 'ranking.matching.approximateThreshold'

    def setUp(self):
        """Set up test fixtures with a structured index that supports both tensor and lexical search."""
        # Create a structured index with both tensor and lexical fields
        marqo_index = self._create_structured_marqo_index(
            name='test_index',
            text_field_names=['title', 'description'],
            tensor_field_names=['title', 'description']
        )
        self.vespa_index = StructuredVespaIndex(marqo_index)

    def _create_structured_marqo_index(
            self,
            name: str,
            text_field_names: Optional[List[str]] = None,
            tensor_field_names: Optional[List[str]] = None
    ) -> StructuredMarqoIndex:
        """Helper method to create a structured Marqo index for testing.

        Args:
            name: Used as both the index name and the schema name.
            text_field_names: Names of text fields to create; each gets the
                lexical-search and filter features. ``None`` creates no text fields.
            tensor_field_names: Names of tensor fields to create. ``None``
                creates no tensor fields.

        Returns:
            The constructed ``StructuredMarqoIndex``.
        """
        # Defaults are None rather than [] so that no list object is shared
        # between calls (mutable default arguments are evaluated only once).
        fields = [
            Field(
                name=field_name,
                type=FieldType.Text,
                features=[FieldFeature.LexicalSearch, FieldFeature.Filter],
                lexical_field_name=f'{field_name}_lexical',
                filter_field_name=f'{field_name}_filter'
            )
            for field_name in (text_field_names or [])
        ]

        tensor_fields = [
            TensorField(
                name=field_name,
                embeddings_field_name=f'{field_name}_embeddings',
                chunk_field_name=f'{field_name}_chunks'
            )
            for field_name in (tensor_field_names or [])
        ]

        # One timestamp for both fields keeps created_at and updated_at equal.
        now = time.time()

        return StructuredMarqoIndex(
            name=name,
            schema_name=name,
            model=Model(name='hf/all_datasets_v4_MiniLM-L6'),
            normalize_embeddings=True,
            distance_metric=DistanceMetric.Angular,
            vector_numeric_type='float',
            hnsw_config=HnswConfig(ef_construction=100, m=16),
            marqo_version='2.12.0',  # Version that supports hybrid search
            created_at=now,
            updated_at=now,
            fields=fields,
            tensor_fields=tensor_fields,
            text_preprocessing=TextPreProcessing(
                split_length=2,
                split_overlap=0,
                split_method=TextSplitMethod.Sentence
            ),
            image_preprocessing=ImagePreProcessing(
                patch_method=None
            )
        )

    def _assert_approximate_threshold(self, vespa_query, threshold):
        """Assert the threshold key equals ``threshold``, or is absent when ``threshold`` is None."""
        if threshold is not None:
            self.assertEqual(vespa_query[self.APPROXIMATE_THRESHOLD_KEY], threshold)
        else:
            # When threshold is None, it should not be included in the query
            self.assertNotIn(self.APPROXIMATE_THRESHOLD_KEY, vespa_query)

    def test_to_vespa_query_tensor_mode_approximate_threshold(self):
        """Test that to_vespa_query correctly sets approximate threshold for tensor queries."""
        threshold_values = [0.75, 0.85, 0.95, None]

        for threshold in threshold_values:
            with self.subTest(approximate_threshold=threshold):
                marqo_query = MarqoTensorQuery(
                    index_name='test_index',
                    limit=10,
                    offset=0,
                    vector_query=[0.1, 0.2, 0.3, 0.4],
                    approximate_threshold=threshold,
                    approximate=True
                )

                vespa_query = self.vespa_index.to_vespa_query(marqo_query)

                self._assert_approximate_threshold(vespa_query, threshold)

                # Verify other key fields are present
                self.assertIn('yql', vespa_query)
                self.assertIn('ranking', vespa_query)
                self.assertEqual(vespa_query['hits'], 10)

    def test_to_vespa_query_hybrid_mode_approximate_threshold(self):
        """Test that to_vespa_query correctly sets approximate threshold for hybrid queries."""
        threshold_values = [0.70, 0.80, 0.90, None]

        for threshold in threshold_values:
            with self.subTest(approximate_threshold=threshold):
                hybrid_parameters = HybridParameters(
                    retrievalMethod=RetrievalMethod.Disjunction,
                    rankingMethod=RankingMethod.RRF,
                    alpha=0.7,
                    rrfK=100
                )

                marqo_query = MarqoHybridQuery(
                    index_name='test_index',
                    limit=15,
                    offset=0,
                    vector_query=[0.2, 0.3, 0.4, 0.5],
                    or_phrases=['search', 'query'],
                    and_phrases=['required'],
                    hybrid_parameters=hybrid_parameters,
                    approximate_threshold=threshold,
                    approximate=True
                )

                vespa_query = self.vespa_index.to_vespa_query(marqo_query)

                self._assert_approximate_threshold(vespa_query, threshold)

                # Verify hybrid-specific fields are present
                self.assertEqual(vespa_query['hits'], 15)
                self.assertIn('searchChain', vespa_query)
                self.assertEqual(vespa_query['searchChain'], 'marqo')
                self.assertIn('marqo__hybrid.retrievalMethod', vespa_query)
                self.assertIn('marqo__hybrid.rankingMethod', vespa_query)
153+
154+
155+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import unittest
2+
from pydantic.v1 import ValidationError
3+
4+
from marqo.tensor_search.models.api_models import SearchQuery, CustomVectorQuery
5+
from marqo.tensor_search.enums import SearchMethod
6+
from marqo.core.models.hybrid_parameters import HybridParameters, RankingMethod, RetrievalMethod
7+
from marqo.core.models.facets_parameters import FacetsParameters, FieldFacetsConfiguration
8+
from marqo.tensor_search.models.search import SearchContext, SearchContextTensor
9+
10+
11+
class TestSearchQuery(unittest.TestCase):
    """Construction and validation checks for the ``SearchQuery`` API model."""

    def _expect_validation_error(self, expected_fragment, **query_kwargs):
        """Require that ``SearchQuery(**query_kwargs)`` is rejected.

        Asserts that a ``ValidationError`` is raised and that
        ``expected_fragment`` occurs in the string form of that error.
        """
        with self.assertRaises(ValidationError) as raised:
            SearchQuery(**query_kwargs)
        self.assertIn(expected_fragment, str(raised.exception))

    def test_search_query_with_all_parameters(self):
        """A SearchQuery built with every parameter populated keeps the supplied values."""
        vector_query = CustomVectorQuery(
            customVector=CustomVectorQuery.CustomVector(
                content="test content",
                vector=[0.1, 0.2, 0.3, 0.4]
            )
        )
        hybrid = HybridParameters(
            retrievalMethod=RetrievalMethod.Disjunction,
            rankingMethod=RankingMethod.RRF,
            alpha=0.7,
            rrfK=100
        )
        facet_params = FacetsParameters(
            fields={
                "category": FieldFacetsConfiguration(type="string", maxResults=10)
            }
        )
        search_context = SearchContext(
            tensor=[SearchContextTensor(vector=[0.1, 0.2], weight=1.0)]
        )

        # Collect the arguments first so the full parameter set reads as one table.
        all_parameters = {
            "q": vector_query,
            "searchableAttributes": ["title", "description"],
            "searchMethod": SearchMethod.HYBRID,
            "limit": 20,
            "offset": 5,
            "rerankDepth": 100,
            "efSearch": 200,
            "approximate": True,
            "approximateThreshold": 0.85,
            "showHighlights": False,
            "reRanker": "test_reranker",
            "filter": "category:electronics",
            "attributesToRetrieve": ["title", "price"],
            "boost": {"title": 1.5},
            "mediaDownloadHeaders": {"Authorization": "Bearer token"},
            "context": search_context,
            "textQueryPrefix": "search:",
            "hybridParameters": hybrid,
            "facets": facet_params,
            "trackTotalHits": True,
        }
        query = SearchQuery(**all_parameters)

        # Spot-check a handful of the attributes that were set above.
        self.assertEqual(query.searchMethod, SearchMethod.HYBRID)
        self.assertEqual(query.limit, 20)
        self.assertEqual(query.approximateThreshold, 0.85)
        self.assertIsNotNone(query.hybridParameters)
        self.assertIsNotNone(query.facets)

    def test_search_query_required_parameters_only(self):
        """A SearchQuery given only ``q`` and ``searchMethod`` exposes the expected defaults."""
        query = SearchQuery(q="test query", searchMethod=SearchMethod.TENSOR)

        self.assertEqual(query.searchMethod, SearchMethod.TENSOR)
        self.assertEqual(query.limit, 10)
        self.assertEqual(query.offset, 0)
        self.assertTrue(query.showHighlights)
        self.assertIsNone(query.hybridParameters)

    def test_hybrid_parameters_validation(self):
        """Supplying hybrid parameters with a TENSOR search is rejected."""
        hybrid = HybridParameters(
            retrievalMethod=RetrievalMethod.Disjunction,
            rankingMethod=RankingMethod.RRF
        )

        self._expect_validation_error(
            "Hybrid parameters can only be provided for 'HYBRID' search",
            q="test",
            searchMethod=SearchMethod.TENSOR,
            hybridParameters=hybrid,
        )

    def test_facets_validation(self):
        """Supplying facets with a TENSOR search is rejected."""
        facet_params = FacetsParameters(
            fields={"category": FieldFacetsConfiguration(type="string")}
        )

        self._expect_validation_error(
            "Facets can only be provided for 'HYBRID' search",
            q="test",
            searchMethod=SearchMethod.TENSOR,
            facets=facet_params,
        )

    def test_track_total_hits_validation(self):
        """Supplying trackTotalHits with a TENSOR search is rejected."""
        self._expect_validation_error(
            "trackTotalHits can only be provided for 'HYBRID' search",
            q="test",
            searchMethod=SearchMethod.TENSOR,
            trackTotalHits=True,
        )

    def test_approximate_threshold_validation(self):
        """approximateThreshold is rejected for LEXICAL, with approximate=False, and when out of range."""
        # Checked in order with no subTest, so the first failing case stops the test.
        rejected_cases = [
            (
                "'approximateThreshold' is only valid for 'HYBRID' and 'TENSOR' search methods",
                {"q": "test", "searchMethod": SearchMethod.LEXICAL, "approximateThreshold": 0.5},
            ),
            (
                "'approximateThreshold' cannot be set when 'approximate' is False",
                {
                    "q": "test",
                    "searchMethod": SearchMethod.TENSOR,
                    "approximate": False,
                    "approximateThreshold": 0.5,
                },
            ),
            (
                "'approximateThreshold' must be between 0 and 1",
                {"q": "test", "searchMethod": SearchMethod.TENSOR, "approximateThreshold": 1.5},
            ),
        ]

        for message, query_kwargs in rejected_cases:
            self._expect_validation_error(message, **query_kwargs)

    def test_query_and_context_validation(self):
        """LEXICAL search without ``q`` and TENSOR search without ``q`` or context are rejected."""
        self._expect_validation_error(
            "Query(q) is required for lexical search",
            searchMethod=SearchMethod.LEXICAL,
        )
        self._expect_validation_error(
            "One of Query(q) or context is required for TENSOR search",
            searchMethod=SearchMethod.TENSOR,
        )

    def test_rerank_depth_validation(self):
        """rerankDepth is rejected for LEXICAL search and for negative values."""
        self._expect_validation_error(
            "'rerankDepth' is currently not supported for 'LEXICAL' search method",
            q="test",
            searchMethod=SearchMethod.LEXICAL,
            rerankDepth=10,
        )
        self._expect_validation_error(
            "rerankDepth cannot be negative",
            q="test",
            searchMethod=SearchMethod.TENSOR,
            rerankDepth=-1,
        )

    def test_image_download_headers_validation(self):
        """Both header arguments together are rejected; image headers alone populate mediaDownloadHeaders."""
        self._expect_validation_error(
            "Cannot set both imageDownloadHeaders",
            q="test",
            image_download_headers={"header1": "value1"},
            mediaDownloadHeaders={"header2": "value2"},
        )

        # With only the image headers given, the same mapping is expected on mediaDownloadHeaders.
        query = SearchQuery(
            q="test",
            image_download_headers={"Authorization": "Bearer token"}
        )
        self.assertEqual(query.mediaDownloadHeaders, {"Authorization": "Bearer token"})
205+
206+
207+
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)