1
+ import pytest
2
+ import numpy as np
3
+ from unittest .mock import Mock , patch
4
+ from pymilvus .client .search_iterator import SearchIteratorV2
5
+ from pymilvus .client .abstract import Hits , SearchResult
6
+ from pymilvus .exceptions import ParamError , ServerVersionIncompatibleException
7
+ from pymilvus .grpc_gen import schema_pb2
8
+ from pymilvus .orm .constants import UNLIMITED
9
+
10
+ class TestSearchIteratorV2 :
11
+ @pytest .fixture
12
+ def mock_connection (self ):
13
+ connection = Mock ()
14
+ connection .describe_collection .return_value = {"collection_id" : "test_id" }
15
+ return connection
16
+
17
+ @pytest .fixture
18
+ def search_data (self ):
19
+ np .random .seed (42 )
20
+ return np .random .random ((1 , 8 )).tolist ()
21
+
22
+ def create_mock_search_result (self , num_results = 10 ):
23
+ # Create mock search results
24
+ mock_ids = schema_pb2 .IDs (
25
+ int_id = schema_pb2 .LongArray (data = [i for i in range (num_results )])
26
+ )
27
+ result = schema_pb2 .SearchResultData (
28
+ num_queries = 1 ,
29
+ top_k = num_results ,
30
+ scores = [1.0 * i for i in range (num_results )],
31
+ ids = mock_ids ,
32
+ topks = [num_results ],
33
+ )
34
+
35
+ # Create mock iterator info
36
+ result .search_iterator_v2_results .token = "test_token"
37
+ result .search_iterator_v2_results .last_bound = 0.5
38
+
39
+ return SearchResult (result )
40
+
41
+ def test_init_basic (self , mock_connection , search_data ):
42
+ iterator = SearchIteratorV2 (
43
+ connection = mock_connection ,
44
+ collection_name = "test_collection" ,
45
+ data = search_data ,
46
+ batch_size = 100
47
+ )
48
+
49
+ assert iterator ._batch_size == 100
50
+ assert iterator ._left_res_cnt is None
51
+ assert iterator ._collection_id == "test_id"
52
+
53
+ def test_init_with_limit (self , mock_connection , search_data ):
54
+ iterator = SearchIteratorV2 (
55
+ connection = mock_connection ,
56
+ collection_name = "test_collection" ,
57
+ data = search_data ,
58
+ batch_size = 100 ,
59
+ limit = 50
60
+ )
61
+
62
+ assert iterator ._left_res_cnt == 50
63
+
64
+ def test_invalid_batch_size (self , mock_connection , search_data ):
65
+ with pytest .raises (ParamError ):
66
+ SearchIteratorV2 (
67
+ connection = mock_connection ,
68
+ collection_name = "test_collection" ,
69
+ data = search_data ,
70
+ batch_size = - 1
71
+ )
72
+
73
+ def test_invalid_offset (self , mock_connection , search_data ):
74
+ with pytest .raises (ParamError ):
75
+ SearchIteratorV2 (
76
+ connection = mock_connection ,
77
+ collection_name = "test_collection" ,
78
+ data = search_data ,
79
+ batch_size = 100 ,
80
+ ** {"offset" : 10 }
81
+ )
82
+
83
+ def test_multiple_vectors_error (self , mock_connection ):
84
+ with pytest .raises (ParamError ):
85
+ SearchIteratorV2 (
86
+ connection = mock_connection ,
87
+ collection_name = "test_collection" ,
88
+ data = [[1 , 2 ], [3 , 4 ]], # Multiple vectors
89
+ batch_size = 100
90
+ )
91
+
92
+ @patch ('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability' )
93
+ def test_next_without_external_filter (self , mock_probe , mock_connection , search_data ):
94
+ mock_connection .search .return_value = self .create_mock_search_result ()
95
+ iterator = SearchIteratorV2 (
96
+ connection = mock_connection ,
97
+ collection_name = "test_collection" ,
98
+ data = search_data ,
99
+ batch_size = 100
100
+ )
101
+
102
+ result = iterator .next ()
103
+ assert result is not None
104
+ assert len (result ) == 10 # Number of results from mock
105
+
106
+ @patch ('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability' )
107
+ def test_next_with_limit (self , mock_probe , mock_connection , search_data ):
108
+ mock_connection .search .return_value = self .create_mock_search_result ()
109
+ iterator = SearchIteratorV2 (
110
+ connection = mock_connection ,
111
+ collection_name = "test_collection" ,
112
+ data = search_data ,
113
+ batch_size = 100 ,
114
+ limit = 5
115
+ )
116
+
117
+ result = iterator .next ()
118
+ assert result is not None
119
+ assert len (result ) == 5 # Limited to 5 results
120
+
121
+ def test_server_incompatible (self , mock_connection , search_data ):
122
+ # Mock search result with empty token
123
+ mock_result = self .create_mock_search_result ()
124
+ mock_result ._search_iterator_v2_results .token = ""
125
+ mock_connection .search .return_value = mock_result
126
+
127
+ with pytest .raises (ServerVersionIncompatibleException ):
128
+ SearchIteratorV2 (
129
+ connection = mock_connection ,
130
+ collection_name = "test_collection" ,
131
+ data = search_data ,
132
+ batch_size = 100
133
+ )
134
+
135
+ @patch ('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability' )
136
+ def test_external_filter (self , mock_probe , mock_connection , search_data ):
137
+ mock_connection .search .return_value = self .create_mock_search_result ()
138
+
139
+ def filter_func (hits ):
140
+ return [hit for hit in hits if hit .distance < 5.0 ]
141
+
142
+ iterator = SearchIteratorV2 (
143
+ connection = mock_connection ,
144
+ collection_name = "test_collection" ,
145
+ data = search_data ,
146
+ batch_size = 100 ,
147
+ external_filter_func = filter_func
148
+ )
149
+
150
+ result = iterator .next ()
151
+ assert result is not None
152
+ assert all (hit .distance < 5.0 for hit in result )
153
+
154
+ @patch ('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability' )
155
+ def test_filter_and_external_filter (self , mock_probe , mock_connection , search_data ):
156
+ # Create mock search result with field values
157
+ mock_result = self .create_mock_search_result ()
158
+ for hit in mock_result [0 ]:
159
+ hit .entity .field_1 = hit .id % 2
160
+ mock_result [0 ] = list (filter (lambda x : x .entity .field_1 < 5 , mock_result [0 ]))
161
+ mock_connection .search .return_value = mock_result
162
+
163
+ expr_filter = "field_1 < 5"
164
+
165
+ def filter_func (hits ):
166
+ return [hit for hit in hits if hit .distance < 5.0 ] # Only hits with distance < 5.0 should pass
167
+
168
+ iterator = SearchIteratorV2 (
169
+ connection = mock_connection ,
170
+ collection_name = "test_collection" ,
171
+ data = search_data ,
172
+ batch_size = 100 ,
173
+ filter = expr_filter ,
174
+ external_filter_func = filter_func
175
+ )
176
+
177
+ result = iterator .next ()
178
+ assert result is not None
179
+ assert all (hit .distance < 5.0 and hit .entity .field_1 < 5 for hit in result )
0 commit comments