|
13 | 13 | OmitZeroDict,
|
14 | 14 | construct_cost_extra,
|
15 | 15 | )
|
| 16 | +from pymilvus.client.utils import is_vector_type |
16 | 17 | from pymilvus.exceptions import (
|
17 | 18 | DataTypeNotMatchException,
|
| 19 | + ErrorCode, |
18 | 20 | MilvusException,
|
19 | 21 | ParamError,
|
20 | 22 | PrimaryKeyException,
|
21 | 23 | )
|
22 | 24 | from pymilvus.orm import utility
|
23 | 25 | from pymilvus.orm.collection import CollectionSchema
|
24 | 26 | from pymilvus.orm.connections import connections
|
| 27 | +from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED |
| 28 | +from pymilvus.orm.iterator import QueryIterator, SearchIterator |
25 | 29 | from pymilvus.orm.types import DataType
|
26 | 30 |
|
27 | 31 | from .index import IndexParams
|
@@ -480,6 +484,120 @@ def query(
|
480 | 484 |
|
481 | 485 | return res
|
482 | 486 |
|
| 487 | + def query_iterator( |
| 488 | + self, |
| 489 | + collection_name: str, |
| 490 | + batch_size: Optional[int] = 1000, |
| 491 | + limit: Optional[int] = UNLIMITED, |
| 492 | + filter: Optional[str] = "", |
| 493 | + output_fields: Optional[List[str]] = None, |
| 494 | + partition_names: Optional[List[str]] = None, |
| 495 | + timeout: Optional[float] = None, |
| 496 | + **kwargs, |
| 497 | + ): |
| 498 | + if filter is not None and not isinstance(filter, str): |
| 499 | + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) |
| 500 | + |
| 501 | + conn = self._get_connection() |
| 502 | + # set up schema for iterator |
| 503 | + try: |
| 504 | + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) |
| 505 | + except Exception as ex: |
| 506 | + logger.error("Failed to describe collection: %s", collection_name) |
| 507 | + raise ex from ex |
| 508 | + |
| 509 | + return QueryIterator( |
| 510 | + connection=conn, |
| 511 | + collection_name=collection_name, |
| 512 | + batch_size=batch_size, |
| 513 | + limit=limit, |
| 514 | + expr=filter, |
| 515 | + output_fields=output_fields, |
| 516 | + partition_names=partition_names, |
| 517 | + schema=schema_dict, |
| 518 | + timeout=timeout, |
| 519 | + **kwargs, |
| 520 | + ) |
| 521 | + |
| 522 | + def search_iterator( |
| 523 | + self, |
| 524 | + collection_name: str, |
| 525 | + data: Union[List[list], list], |
| 526 | + batch_size: Optional[int] = 1000, |
| 527 | + filter: Optional[str] = None, |
| 528 | + limit: Optional[int] = UNLIMITED, |
| 529 | + output_fields: Optional[List[str]] = None, |
| 530 | + search_params: Optional[dict] = None, |
| 531 | + timeout: Optional[float] = None, |
| 532 | + partition_names: Optional[List[str]] = None, |
| 533 | + anns_field: Optional[str] = None, |
| 534 | + round_decimal: int = -1, |
| 535 | + **kwargs, |
| 536 | + ): |
| 537 | + if filter is not None and not isinstance(filter, str): |
| 538 | + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) |
| 539 | + |
| 540 | + conn = self._get_connection() |
| 541 | + # set up schema for iterator |
| 542 | + try: |
| 543 | + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) |
| 544 | + except Exception as ex: |
| 545 | + logger.error("Failed to describe collection: %s", collection_name) |
| 546 | + raise ex from ex |
| 547 | + # if anns_field is not provided |
| 548 | + # if only one vector field, use to search |
| 549 | + # if multiple vector fields, raise exception and abort |
| 550 | + |
| 551 | + if anns_field is None or anns_field == "": |
| 552 | + vec_field = None |
| 553 | + fields = schema_dict[FIELDS] |
| 554 | + vec_field_count = 0 |
| 555 | + for field in fields: |
| 556 | + if is_vector_type(field[TYPE]): |
| 557 | + vec_field_count += 1 |
| 558 | + vec_field = field |
| 559 | + if vec_field is None: |
| 560 | + raise MilvusException( |
| 561 | + code=ErrorCode.UNEXPECTED_ERROR, |
| 562 | + message="there should be at least one vector field in milvus collection", |
| 563 | + ) |
| 564 | + if vec_field_count > 1: |
| 565 | + raise MilvusException( |
| 566 | + code=ErrorCode.UNEXPECTED_ERROR, |
| 567 | + message="must specify anns_field when there are more than one vector field", |
| 568 | + ) |
| 569 | + anns_field = vec_field["name"] |
| 570 | + |
| 571 | + if search_params is None: |
| 572 | + search_params = {} |
| 573 | + if METRIC_TYPE not in search_params: |
| 574 | + indexes = conn.list_indexes(collection_name) |
| 575 | + for index in indexes: |
| 576 | + if anns_field == index.index_name: |
| 577 | + params = index.params |
| 578 | + for param in params: |
| 579 | + if param.key == METRIC_TYPE: |
| 580 | + search_params[METRIC_TYPE] = param.value |
| 581 | + if METRIC_TYPE not in search_params: |
| 582 | + raise MilvusException(ParamError, "Must provide metrics type for search iterator") |
| 583 | + |
| 584 | + return SearchIterator( |
| 585 | + connection=self._get_connection(), |
| 586 | + collection_name=collection_name, |
| 587 | + data=data, |
| 588 | + ann_field=anns_field, |
| 589 | + param=search_params, |
| 590 | + batch_size=batch_size, |
| 591 | + limit=limit, |
| 592 | + expr=filter, |
| 593 | + partition_names=partition_names, |
| 594 | + output_fields=output_fields, |
| 595 | + timeout=timeout, |
| 596 | + round_decimal=round_decimal, |
| 597 | + schema=schema_dict, |
| 598 | + **kwargs, |
| 599 | + ) |
| 600 | + |
483 | 601 | def get(
|
484 | 602 | self,
|
485 | 603 | collection_name: str,
|
|
0 commit comments