Skip to content

Commit 0f20fed

Browse files
author
MrPresent-Han
committed
feature: support milvus-client iterator
Signed-off-by: MrPresent-Han <[email protected]>
1 parent 712e9b6 commit 0f20fed

File tree

4 files changed

+241
-0
lines changed

4 files changed

+241
-0
lines changed

examples/iterator/iterator.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from pymilvus.milvus_client.milvus_client import MilvusClient
2+
from pymilvus import (
3+
FieldSchema, CollectionSchema, DataType,
4+
)
5+
import numpy as np
6+
7+
collection_name = "test_milvus_client_iterator"
8+
prepare_new_data = True
9+
clean_exist = True
10+
11+
USER_ID = "id"
12+
AGE = "age"
13+
DEPOSIT = "deposit"
14+
PICTURE = "picture"
15+
DIM = 8
16+
NUM_ENTITIES = 10000
17+
rng = np.random.default_rng(seed=19530)
18+
19+
20+
def test_query_iterator(milvus_client: MilvusClient):
21+
# test query iterator
22+
expr = f"10 <= {AGE} <= 25"
23+
output_fields = [USER_ID, AGE]
24+
queryIt = milvus_client.query_iterator(collection_name, filter=expr, batch_size=50, output_fields=output_fields)
25+
page_idx = 0
26+
while True:
27+
res = queryIt.next()
28+
if len(res) == 0:
29+
print("query iteration finished, close")
30+
queryIt.close()
31+
break
32+
for i in range(len(res)):
33+
print(res[i])
34+
page_idx += 1
35+
print(f"page{page_idx}-------------------------")
36+
37+
def test_search_iterator(milvus_client: MilvusClient):
38+
vector_to_search = rng.random((1, DIM), np.float32)
39+
search_iterator = milvus_client.search_iterator(collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE)
40+
41+
page_idx = 0
42+
while True:
43+
res = search_iterator.next()
44+
if len(res) == 0:
45+
print("query iteration finished, close")
46+
search_iterator.close()
47+
break
48+
for i in range(len(res)):
49+
print(res[i])
50+
page_idx += 1
51+
print(f"page{page_idx}-------------------------")
52+
53+
54+
def main():
55+
milvus_client = MilvusClient("http://localhost:19530")
56+
if milvus_client.has_collection(collection_name) and clean_exist:
57+
milvus_client.drop_collection(collection_name)
58+
print(f"dropped existed collection{collection_name}")
59+
60+
if not milvus_client.has_collection(collection_name):
61+
fields = [
62+
FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
63+
FieldSchema(name=AGE, dtype=DataType.INT64),
64+
FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
65+
FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM)
66+
]
67+
schema = CollectionSchema(fields)
68+
milvus_client.create_collection(collection_name, dimension=DIM, schema=schema)
69+
70+
if prepare_new_data:
71+
entities = []
72+
for i in range(NUM_ENTITIES):
73+
entity = {
74+
USER_ID: i,
75+
AGE: (i % 100),
76+
DEPOSIT: float(i),
77+
PICTURE: rng.random((1, DIM))[0]
78+
}
79+
entities.append(entity)
80+
milvus_client.insert(collection_name, entities)
81+
milvus_client.flush(collection_name)
82+
print(f"Finish flush collections:{collection_name}")
83+
84+
index_params = milvus_client.prepare_index_params()
85+
86+
index_params.add_index(
87+
field_name=PICTURE,
88+
index_type='IVF_FLAT',
89+
metric_type='L2',
90+
params={"nlist": 1024}
91+
)
92+
milvus_client.create_index(collection_name, index_params)
93+
milvus_client.load_collection(collection_name)
94+
#test_query_iterator(milvus_client=milvus_client)
95+
test_search_iterator(milvus_client=milvus_client)
96+
97+
98+
if __name__ == '__main__':
99+
main()

pymilvus/client/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,26 @@ def is_scipy_sparse(cls, data: Any):
375375
"csr_array",
376376
"spmatrix",
377377
]
378+
379+
380+
def is_sparse_vector_type(data_type: DataType) -> bool:
381+
return data_type == data_type.SPARSE_FLOAT_VECTOR
382+
383+
384+
dense_vector_type_set = {DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR}
385+
386+
387+
def is_dense_vector_type(data_type: DataType) -> bool:
388+
return data_type in dense_vector_type_set
389+
390+
391+
def is_float_vector_type(data_type: DataType):
392+
return is_sparse_vector_type(data_type) or is_dense_vector_type(data_type)
393+
394+
395+
def is_binary_vector_type(data_type: DataType):
396+
return data_type == DataType.BINARY_VECTOR
397+
398+
399+
def is_vector_type(data_type: DataType):
400+
return is_float_vector_type(data_type) or is_binary_vector_type(data_type)

pymilvus/milvus_client/milvus_client.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
OmitZeroDict,
1414
construct_cost_extra,
1515
)
16+
from pymilvus.client.utils import is_vector_type
1617
from pymilvus.exceptions import (
1718
DataTypeNotMatchException,
19+
ErrorCode,
1820
MilvusException,
1921
ParamError,
2022
PrimaryKeyException,
2123
)
2224
from pymilvus.orm import utility
2325
from pymilvus.orm.collection import CollectionSchema
2426
from pymilvus.orm.connections import connections
27+
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
28+
from pymilvus.orm.iterator import QueryIterator, SearchIterator
2529
from pymilvus.orm.types import DataType
2630

2731
from .index import IndexParams
@@ -480,6 +484,120 @@ def query(
480484

481485
return res
482486

487+
def query_iterator(
488+
self,
489+
collection_name: str,
490+
batch_size: Optional[int] = 1000,
491+
limit: Optional[int] = UNLIMITED,
492+
filter: Optional[str] = "",
493+
output_fields: Optional[List[str]] = None,
494+
partition_names: Optional[List[str]] = None,
495+
timeout: Optional[float] = None,
496+
**kwargs,
497+
):
498+
if filter is not None and not isinstance(filter, str):
499+
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))
500+
501+
conn = self._get_connection()
502+
# set up schema for iterator
503+
try:
504+
schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
505+
except Exception as ex:
506+
logger.error("Failed to describe collection: %s", collection_name)
507+
raise ex from ex
508+
509+
return QueryIterator(
510+
connection=conn,
511+
collection_name=collection_name,
512+
batch_size=batch_size,
513+
limit=limit,
514+
expr=filter,
515+
output_fields=output_fields,
516+
partition_names=partition_names,
517+
schema=schema_dict,
518+
timeout=timeout,
519+
**kwargs,
520+
)
521+
522+
def search_iterator(
523+
self,
524+
collection_name: str,
525+
data: Union[List[list], list],
526+
batch_size: Optional[int] = 1000,
527+
filter: Optional[str] = None,
528+
limit: Optional[int] = UNLIMITED,
529+
output_fields: Optional[List[str]] = None,
530+
search_params: Optional[dict] = None,
531+
timeout: Optional[float] = None,
532+
partition_names: Optional[List[str]] = None,
533+
anns_field: Optional[str] = None,
534+
round_decimal: int = -1,
535+
**kwargs,
536+
):
537+
if filter is not None and not isinstance(filter, str):
538+
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))
539+
540+
conn = self._get_connection()
541+
# set up schema for iterator
542+
try:
543+
schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
544+
except Exception as ex:
545+
logger.error("Failed to describe collection: %s", collection_name)
546+
raise ex from ex
547+
# if anns_field is not provided
548+
# if only one vector field, use to search
549+
# if multiple vector fields, raise exception and abort
550+
551+
if anns_field is None or anns_field == "":
552+
vec_field = None
553+
fields = schema_dict[FIELDS]
554+
vec_field_count = 0
555+
for field in fields:
556+
if is_vector_type(field[TYPE]):
557+
vec_field_count += 1
558+
vec_field = field
559+
if vec_field is None:
560+
raise MilvusException(
561+
code=ErrorCode.UNEXPECTED_ERROR,
562+
message="there should be at least one vector field in milvus collection",
563+
)
564+
if vec_field_count > 1:
565+
raise MilvusException(
566+
code=ErrorCode.UNEXPECTED_ERROR,
567+
message="must specify anns_field when there are more than one vector field",
568+
)
569+
anns_field = vec_field["name"]
570+
571+
if search_params is None:
572+
search_params = {}
573+
if METRIC_TYPE not in search_params:
574+
indexes = conn.list_indexes(collection_name)
575+
for index in indexes:
576+
if anns_field == index.index_name:
577+
params = index.params
578+
for param in params:
579+
if param.key == METRIC_TYPE:
580+
search_params[METRIC_TYPE] = param.value
581+
if METRIC_TYPE not in search_params:
582+
raise MilvusException(ParamError, "Must provide metrics type for search iterator")
583+
584+
return SearchIterator(
585+
connection=self._get_connection(),
586+
collection_name=collection_name,
587+
data=data,
588+
ann_field=anns_field,
589+
param=search_params,
590+
batch_size=batch_size,
591+
limit=limit,
592+
expr=filter,
593+
partition_names=partition_names,
594+
output_fields=output_fields,
595+
timeout=timeout,
596+
round_decimal=round_decimal,
597+
schema=schema_dict,
598+
**kwargs,
599+
)
600+
483601
def get(
484602
self,
485603
collection_name: str,

pymilvus/orm/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
MILVUS_LIMIT = "limit"
3838
BATCH_SIZE = "batch_size"
3939
ID = "id"
40+
TYPE = "type"
4041
METRIC_TYPE = "metric_type"
4142
PARAMS = "params"
4243
DISTANCE = "distance"

0 commit comments

Comments
 (0)