Skip to content

Commit 43779e8

Browse files
authored
feat: [cp][2.5] add external filter func for search iterator v2 (#2641)
issue: #2640 Milvus issue: milvus-io/milvus#39914 pr: #2639 Signed-off-by: Patrick Weizhi Xu <[email protected]> (cherry picked from commit 6112bea)
1 parent 6dc6281 commit 43779e8

File tree

3 files changed

+267
-10
lines changed

3 files changed

+267
-10
lines changed

examples/iterator/iterator.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pymilvus.client.abstract import Hits
12
from pymilvus.milvus_client.milvus_client import MilvusClient
23
from pymilvus import (
34
FieldSchema, CollectionSchema, DataType,
@@ -42,14 +43,51 @@ def test_search_iterator(milvus_client: MilvusClient):
4243
while True:
4344
res = search_iterator.next()
4445
if len(res) == 0:
45-
print("query iteration finished, close")
46+
print("search iteration finished, close")
4647
search_iterator.close()
4748
break
4849
for i in range(len(res)):
4950
print(res[i])
5051
page_idx += 1
5152
print(f"page{page_idx}-------------------------")
5253

54+
def test_search_iterator_with_filter(milvus_client: MilvusClient):
    """Page through a search iterator that applies a client-side filter.

    The server-side expression restricts AGE to the range [10, 25]; the
    external filter function then keeps only hits whose id is in
    ``valid_ids``. Every page is printed until an empty page is returned,
    at which point the iterator is closed.
    """
    vector_to_search = rng.random((1, DIM), np.float32)
    expr = f"10 <= {AGE} <= 25"
    # A set gives O(1) membership tests inside the filter.
    valid_ids = {1, 12, 123, 1234}

    def external_filter_func(hits: Hits):
        # The callback receives one batch of hits and returns the hits to
        # keep. Any equivalent form works as well, e.g. an explicit loop
        # that appends the matching hits to a list and returns it.
        return [hit for hit in hits if hit.id in valid_ids]

    search_iterator = milvus_client.search_iterator(
        collection_name=collection_name,
        data=vector_to_search,
        batch_size=100,
        anns_field=PICTURE,
        filter=expr,
        external_filter_func=external_filter_func,
        output_fields=[USER_ID, AGE],
    )

    page_idx = 0
    while True:
        res = search_iterator.next()
        if len(res) == 0:
            print("search iteration with external filter finished, close")
            search_iterator.close()
            break
        for hit in res:
            print(hit)
        page_idx += 1
        print(f"page{page_idx}-------------------------")
5391

5492
def main():
5593
milvus_client = MilvusClient("http://localhost:19530")
@@ -93,6 +131,7 @@ def main():
93131
milvus_client.load_collection(collection_name)
94132
test_query_iterator(milvus_client=milvus_client)
95133
test_search_iterator(milvus_client=milvus_client)
134+
test_search_iterator_with_filter(milvus_client=milvus_client)
96135

97136

98137
if __name__ == '__main__':

pymilvus/client/search_iterator.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22
from copy import deepcopy
3-
from typing import Dict, List, Optional, Union
3+
from typing import Callable, Dict, List, Optional, Union
44

55
from pymilvus.client import entity_helper, utils
6-
from pymilvus.client.abstract import Hits
6+
from pymilvus.client.abstract import Hit, Hits
77
from pymilvus.client.constants import (
88
COLLECTION_ID,
99
GUARANTEE_TIMESTAMP,
@@ -39,6 +39,7 @@ def __init__(
3939
partition_names: Optional[List[str]] = None,
4040
anns_field: Optional[str] = None,
4141
round_decimal: Optional[int] = -1,
42+
external_filter_func: Optional[Callable[[Hits], Union[Hits, List[Hit]]]] = None,
4243
**kwargs,
4344
):
4445
self._check_params(batch_size, data, kwargs)
@@ -67,6 +68,9 @@ def __init__(
6768
GUARANTEE_TIMESTAMP: 0,
6869
**kwargs,
6970
}
71+
self._external_filter_func = external_filter_func
72+
self._cache = []
73+
self._batch_size = batch_size
7074
self._probe_for_compability(self._params)
7175

7276
def _set_up_collection_id(self, collection_name: str):
@@ -89,10 +93,8 @@ def _probe_for_compability(self, params: Dict):
8993
iter_info = self._conn.search(**dummy_params).get_search_iterator_v2_results_info()
9094
self._check_token_exists(iter_info.token)
9195

92-
def next(self):
93-
if self._left_res_cnt is not None and self._left_res_cnt <= 0:
94-
return SearchPage(None)
95-
96+
# internal next function, do not use this outside of this class
97+
def _next(self):
9698
res = self._conn.search(**self._params)
9799
iter_info = res.get_search_iterator_v2_results_info()
98100
self._check_token_exists(iter_info.token)
@@ -110,11 +112,45 @@ def next(self):
110112
"failed to set up mvccTs from milvus server, use client-side ts instead"
111113
)
112114
self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()
115+
return res
113116

117+
def next(self):
    """Return the next page of search results.

    Without an external filter function, one server batch is fetched and
    returned through ``_wrap_return_res``.

    With an external filter function, server batches are fetched and
    filtered until enough hits have been collected or the server returns
    an empty batch. "Enough" is ``batch_size`` when no limit is set,
    otherwise the smaller of ``batch_size`` and the number of results
    still allowed. Filtered hits beyond that amount stay in
    ``self._cache`` and are served by the following call.

    Returns:
        The value produced by ``_wrap_return_res``. Once the limit is
        exhausted an empty page is returned (never ``None``), so callers
        can keep using ``len(page) == 0`` as the stop condition.
    """
    if self._left_res_cnt is not None and self._left_res_cnt <= 0:
        # Limit exhausted: hand back an empty page instead of None so that
        # ``len(page)`` keeps working for existing callers.
        return self._wrap_return_res([])

    if self._external_filter_func is None:
        # return SearchPage for compatibility
        return self._wrap_return_res(self._next()[0])

    # The length of the page should be `batch_size` if no limit is set,
    # otherwise the number of results left if that is less than `batch_size`.
    if self._left_res_cnt is None:
        target_len = self._batch_size
    else:
        target_len = min(self._batch_size, self._left_res_cnt)

    # Only go to the server while the cache cannot satisfy this page.
    while len(self._cache) < target_len:
        hits = self._next()[0]

        # no more results from server
        if len(hits) == 0:
            break

        # apply external filter and keep whatever survives
        self._cache.extend(self._external_filter_func(hits))

    # If the cache holds at most target_len elements, return everything we
    # could possibly return; otherwise return target_len results and keep
    # the rest for the next call.
    ret = self._cache[:target_len]
    del self._cache[:target_len]
    # return SearchPage for compatibility
    return self._wrap_return_res(ret)
118154

119155
def close(self):
120156
pass
@@ -148,6 +184,9 @@ def _check_params(
148184
raise ParamError(message="The vector data for search cannot be empty")
149185

150186
def _wrap_return_res(self, res: Hits) -> SearchPage:
187+
if len(res) == 0:
188+
return SearchPage(None)
189+
151190
if self._left_res_cnt is None:
152191
return SearchPage(res)
153192

tests/test_search_iterator.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import pytest
2+
import numpy as np
3+
from unittest.mock import Mock, patch
4+
from pymilvus.client.search_iterator import SearchIteratorV2
5+
from pymilvus.client.abstract import Hits, SearchResult
6+
from pymilvus.exceptions import ParamError, ServerVersionIncompatibleException
7+
from pymilvus.grpc_gen import schema_pb2
8+
from pymilvus.orm.constants import UNLIMITED
9+
10+
class TestSearchIteratorV2:
    """Unit tests for SearchIteratorV2 driven by a mocked connection."""

    @pytest.fixture
    def mock_connection(self):
        """Mock connection whose describe_collection reports a fixed id."""
        conn = Mock()
        conn.describe_collection.return_value = {"collection_id": "test_id"}
        return conn

    @pytest.fixture
    def search_data(self):
        """A single seeded random 8-dimensional query vector."""
        np.random.seed(42)
        vectors = np.random.random((1, 8))
        return vectors.tolist()

    def create_mock_search_result(self, num_results=10):
        """Build a SearchResult with ids 0..n-1 and scores equal to the id."""
        id_range = range(num_results)
        payload = schema_pb2.SearchResultData(
            num_queries=1,
            top_k=num_results,
            scores=[1.0 * idx for idx in id_range],
            ids=schema_pb2.IDs(int_id=schema_pb2.LongArray(data=list(id_range))),
            topks=[num_results],
        )

        # Iterator info carried alongside the hits
        iter_info = payload.search_iterator_v2_results
        iter_info.token = "test_token"
        iter_info.last_bound = 0.5

        return SearchResult(payload)

    def _new_iterator(self, connection, data, batch_size=100, **extra):
        """Construct a SearchIteratorV2 on the shared test collection."""
        return SearchIteratorV2(
            connection=connection,
            collection_name="test_collection",
            data=data,
            batch_size=batch_size,
            **extra,
        )

    def test_init_basic(self, mock_connection, search_data):
        it = self._new_iterator(mock_connection, search_data)

        assert it._batch_size == 100
        assert it._left_res_cnt is None
        assert it._collection_id == "test_id"

    def test_init_with_limit(self, mock_connection, search_data):
        it = self._new_iterator(mock_connection, search_data, limit=50)

        assert it._left_res_cnt == 50

    def test_invalid_batch_size(self, mock_connection, search_data):
        with pytest.raises(ParamError):
            self._new_iterator(mock_connection, search_data, batch_size=-1)

    def test_invalid_offset(self, mock_connection, search_data):
        with pytest.raises(ParamError):
            self._new_iterator(mock_connection, search_data, offset=10)

    def test_multiple_vectors_error(self, mock_connection):
        # More than one query vector is rejected
        with pytest.raises(ParamError):
            self._new_iterator(mock_connection, [[1, 2], [3, 4]])

    @patch('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability')
    def test_next_without_external_filter(self, mock_probe, mock_connection, search_data):
        mock_connection.search.return_value = self.create_mock_search_result()
        it = self._new_iterator(mock_connection, search_data)

        page = it.next()
        assert page is not None
        assert len(page) == 10  # Number of results from mock

    @patch('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability')
    def test_next_with_limit(self, mock_probe, mock_connection, search_data):
        mock_connection.search.return_value = self.create_mock_search_result()
        it = self._new_iterator(mock_connection, search_data, limit=5)

        page = it.next()
        assert page is not None
        assert len(page) == 5  # Limited to 5 results

    def test_server_incompatible(self, mock_connection, search_data):
        # Search result carrying an empty iterator token
        empty_token_result = self.create_mock_search_result()
        empty_token_result._search_iterator_v2_results.token = ""
        mock_connection.search.return_value = empty_token_result

        with pytest.raises(ServerVersionIncompatibleException):
            self._new_iterator(mock_connection, search_data)

    @patch('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability')
    def test_external_filter(self, mock_probe, mock_connection, search_data):
        mock_connection.search.return_value = self.create_mock_search_result()

        def filter_func(hits):
            kept = []
            for hit in hits:
                if hit.distance < 5.0:
                    kept.append(hit)
            return kept

        it = self._new_iterator(
            mock_connection, search_data, external_filter_func=filter_func
        )

        page = it.next()
        assert page is not None
        assert all(hit.distance < 5.0 for hit in page)

    @patch('pymilvus.client.search_iterator.SearchIteratorV2._probe_for_compability')
    def test_filter_and_external_filter(self, mock_probe, mock_connection, search_data):
        # Search result whose hits carry a field value
        mock_result = self.create_mock_search_result()
        first_page = mock_result[0]
        for hit in first_page:
            hit.entity.field_1 = hit.id % 2
        mock_result[0] = [hit for hit in first_page if hit.entity.field_1 < 5]
        mock_connection.search.return_value = mock_result

        expr_filter = "field_1 < 5"

        def filter_func(hits):
            # Only hits with distance < 5.0 should pass
            kept = []
            for hit in hits:
                if hit.distance < 5.0:
                    kept.append(hit)
            return kept

        it = self._new_iterator(
            mock_connection,
            search_data,
            filter=expr_filter,
            external_filter_func=filter_func,
        )

        page = it.next()
        assert page is not None
        assert all(hit.distance < 5.0 and hit.entity.field_1 < 5 for hit in page)

0 commit comments

Comments
 (0)