1
+ import unittest
2
+ from pydantic .v1 import ValidationError
3
+
4
+ from marqo .tensor_search .models .api_models import SearchQuery , CustomVectorQuery
5
+ from marqo .tensor_search .enums import SearchMethod
6
+ from marqo .core .models .hybrid_parameters import HybridParameters , RankingMethod , RetrievalMethod
7
+ from marqo .core .models .facets_parameters import FacetsParameters , FieldFacetsConfiguration
8
+ from marqo .tensor_search .models .search import SearchContext , SearchContextTensor
9
+
10
+
11
+ class TestSearchQuery (unittest .TestCase ):
12
+
13
+ def test_search_query_with_all_parameters (self ):
14
+ """Test SearchQuery creation with all parameters set to valid values."""
15
+ custom_vector_query = CustomVectorQuery (
16
+ customVector = CustomVectorQuery .CustomVector (
17
+ content = "test content" ,
18
+ vector = [0.1 , 0.2 , 0.3 , 0.4 ]
19
+ )
20
+ )
21
+
22
+ hybrid_parameters = HybridParameters (
23
+ retrievalMethod = RetrievalMethod .Disjunction ,
24
+ rankingMethod = RankingMethod .RRF ,
25
+ alpha = 0.7 ,
26
+ rrfK = 100
27
+ )
28
+
29
+ facets = FacetsParameters (
30
+ fields = {
31
+ "category" : FieldFacetsConfiguration (type = "string" , maxResults = 10 )
32
+ }
33
+ )
34
+
35
+ context = SearchContext (
36
+ tensor = [SearchContextTensor (vector = [0.1 , 0.2 ], weight = 1.0 )]
37
+ )
38
+
39
+ search_query = SearchQuery (
40
+ q = custom_vector_query ,
41
+ searchableAttributes = ["title" , "description" ],
42
+ searchMethod = SearchMethod .HYBRID ,
43
+ limit = 20 ,
44
+ offset = 5 ,
45
+ rerankDepth = 100 ,
46
+ efSearch = 200 ,
47
+ approximate = True ,
48
+ approximateThreshold = 0.85 ,
49
+ showHighlights = False ,
50
+ reRanker = "test_reranker" ,
51
+ filter = "category:electronics" ,
52
+ attributesToRetrieve = ["title" , "price" ],
53
+ boost = {"title" : 1.5 },
54
+ mediaDownloadHeaders = {"Authorization" : "Bearer token" },
55
+ context = context ,
56
+ textQueryPrefix = "search:" ,
57
+ hybridParameters = hybrid_parameters ,
58
+ facets = facets ,
59
+ trackTotalHits = True
60
+ )
61
+
62
+ # Verify key attributes
63
+ self .assertEqual (search_query .searchMethod , SearchMethod .HYBRID )
64
+ self .assertEqual (search_query .limit , 20 )
65
+ self .assertEqual (search_query .approximateThreshold , 0.85 )
66
+ self .assertIsNotNone (search_query .hybridParameters )
67
+ self .assertIsNotNone (search_query .facets )
68
+
69
+ def test_search_query_required_parameters_only (self ):
70
+ """Test SearchQuery with only required parameters."""
71
+ # For tensor search, either q or context is required
72
+ search_query = SearchQuery (
73
+ q = "test query" ,
74
+ searchMethod = SearchMethod .TENSOR
75
+ )
76
+
77
+ # Verify defaults
78
+ self .assertEqual (search_query .searchMethod , SearchMethod .TENSOR )
79
+ self .assertEqual (search_query .limit , 10 )
80
+ self .assertEqual (search_query .offset , 0 )
81
+ self .assertTrue (search_query .showHighlights )
82
+ self .assertIsNone (search_query .hybridParameters )
83
+
84
+ def test_hybrid_parameters_validation (self ):
85
+ """Test that hybrid parameters are only allowed for hybrid search."""
86
+ hybrid_parameters = HybridParameters (
87
+ retrievalMethod = RetrievalMethod .Disjunction ,
88
+ rankingMethod = RankingMethod .RRF
89
+ )
90
+
91
+ # Should fail for tensor search
92
+ with self .assertRaises (ValidationError ) as cm :
93
+ SearchQuery (
94
+ q = "test" ,
95
+ searchMethod = SearchMethod .TENSOR ,
96
+ hybridParameters = hybrid_parameters
97
+ )
98
+ self .assertIn ("Hybrid parameters can only be provided for 'HYBRID' search" , str (cm .exception ))
99
+
100
+ def test_facets_validation (self ):
101
+ """Test that facets are only allowed for hybrid search."""
102
+ facets = FacetsParameters (
103
+ fields = {"category" : FieldFacetsConfiguration (type = "string" )}
104
+ )
105
+
106
+ # Should fail for tensor search
107
+ with self .assertRaises (ValidationError ) as cm :
108
+ SearchQuery (
109
+ q = "test" ,
110
+ searchMethod = SearchMethod .TENSOR ,
111
+ facets = facets
112
+ )
113
+ self .assertIn ("Facets can only be provided for 'HYBRID' search" , str (cm .exception ))
114
+
115
+ def test_track_total_hits_validation (self ):
116
+ """Test that trackTotalHits is only allowed for hybrid search."""
117
+ # Should fail for tensor search
118
+ with self .assertRaises (ValidationError ) as cm :
119
+ SearchQuery (
120
+ q = "test" ,
121
+ searchMethod = SearchMethod .TENSOR ,
122
+ trackTotalHits = True
123
+ )
124
+ self .assertIn ("trackTotalHits can only be provided for 'HYBRID' search" , str (cm .exception ))
125
+
126
+ def test_approximate_threshold_validation (self ):
127
+ """Test approximate threshold validation."""
128
+ # Should fail for lexical search
129
+ with self .assertRaises (ValidationError ) as cm :
130
+ SearchQuery (
131
+ q = "test" ,
132
+ searchMethod = SearchMethod .LEXICAL ,
133
+ approximateThreshold = 0.5
134
+ )
135
+ self .assertIn ("'approximateThreshold' is only valid for 'HYBRID' and 'TENSOR' search methods" , str (cm .exception ))
136
+
137
+ # Should fail when approximate=False
138
+ with self .assertRaises (ValidationError ) as cm :
139
+ SearchQuery (
140
+ q = "test" ,
141
+ searchMethod = SearchMethod .TENSOR ,
142
+ approximate = False ,
143
+ approximateThreshold = 0.5
144
+ )
145
+ self .assertIn ("'approximateThreshold' cannot be set when 'approximate' is False" , str (cm .exception ))
146
+
147
+ # Should fail for invalid range
148
+ with self .assertRaises (ValidationError ) as cm :
149
+ SearchQuery (
150
+ q = "test" ,
151
+ searchMethod = SearchMethod .TENSOR ,
152
+ approximateThreshold = 1.5
153
+ )
154
+ self .assertIn ("'approximateThreshold' must be between 0 and 1" , str (cm .exception ))
155
+
156
+ def test_query_and_context_validation (self ):
157
+ """Test validation of query and context requirements."""
158
+ # Lexical search requires query
159
+ with self .assertRaises (ValidationError ) as cm :
160
+ SearchQuery (searchMethod = SearchMethod .LEXICAL )
161
+ self .assertIn ("Query(q) is required for lexical search" , str (cm .exception ))
162
+
163
+ # Tensor search requires either query or context
164
+ with self .assertRaises (ValidationError ) as cm :
165
+ SearchQuery (searchMethod = SearchMethod .TENSOR )
166
+ self .assertIn ("One of Query(q) or context is required for TENSOR search" , str (cm .exception ))
167
+
168
+ def test_rerank_depth_validation (self ):
169
+ """Test rerank depth validation."""
170
+ # Should fail for lexical search
171
+ with self .assertRaises (ValidationError ) as cm :
172
+ SearchQuery (
173
+ q = "test" ,
174
+ searchMethod = SearchMethod .LEXICAL ,
175
+ rerankDepth = 10
176
+ )
177
+ self .assertIn ("'rerankDepth' is currently not supported for 'LEXICAL' search method" , str (cm .exception ))
178
+
179
+ # Should fail for negative values
180
+ with self .assertRaises (ValidationError ) as cm :
181
+ SearchQuery (
182
+ q = "test" ,
183
+ searchMethod = SearchMethod .TENSOR ,
184
+ rerankDepth = - 1
185
+ )
186
+ self .assertIn ("rerankDepth cannot be negative" , str (cm .exception ))
187
+
188
+ def test_image_download_headers_validation (self ):
189
+ """Test validation of image download headers."""
190
+ # Should fail when both headers are set
191
+ with self .assertRaises (ValidationError ) as cm :
192
+ SearchQuery (
193
+ q = "test" ,
194
+ image_download_headers = {"header1" : "value1" },
195
+ mediaDownloadHeaders = {"header2" : "value2" }
196
+ )
197
+ self .assertIn ("Cannot set both imageDownloadHeaders" , str (cm .exception ))
198
+
199
+ # Should work when imageDownloadHeaders is set and mediaDownloadHeaders is copied
200
+ search_query = SearchQuery (
201
+ q = "test" ,
202
+ image_download_headers = {"Authorization" : "Bearer token" }
203
+ )
204
+ self .assertEqual (search_query .mediaDownloadHeaders , {"Authorization" : "Bearer token" })
205
+
206
+
207
+ if __name__ == '__main__' :
208
+ unittest .main ()
0 commit comments