Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 22 additions & 43 deletions pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,37 +266,24 @@ async def drop_collection(
async def load_collection(
self,
collection_name: str,
replica_number: int = 1,
replica_number: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
collection_name=collection_name, replica_number=replica_number, timeout=timeout
)
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)

request = Prepare.load_collection(
"",
collection_name,
replica_number,
refresh,
resource_groups,
load_fields,
skip_load_dynamic_field,
)
check_pass_param(timeout=timeout)
request = Prepare.load_collection(collection_name, replica_number, **kwargs)
response = await self._async_stub.LoadCollection(
request, timeout=timeout, metadata=_api_level_md(**kwargs)
)
check_status(response)

await self.wait_for_loading_collection(
collection_name, timeout, is_refresh=refresh, **kwargs
collection_name=collection_name,
is_refresh=request.refresh,
timeout=timeout,
**kwargs,
)

@retry_on_rpc_failure()
Expand All @@ -314,7 +301,10 @@ def can_loop(t: int) -> bool:

while can_loop(time.time()):
progress = await self.get_loading_progress(
collection_name, timeout=timeout, is_refresh=is_refresh, **kwargs
collection_name=collection_name,
is_refresh=is_refresh,
timeout=timeout,
**kwargs,
)
if progress >= 100:
return
Expand Down Expand Up @@ -825,41 +815,30 @@ async def load_partitions(
self,
collection_name: str,
partition_names: List[str],
replica_number: int = 1,
replica_number: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
collection_name=collection_name,
partition_name_array=partition_names,
replica_number=replica_number,
timeout=timeout,
)
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)
check_pass_param(timeout=timeout)

request = Prepare.load_partitions(
"",
collection_name,
partition_names,
replica_number,
refresh,
resource_groups,
load_fields,
skip_load_dynamic_field,
collection_name=collection_name,
partition_names=partition_names,
replica_number=replica_number,
**kwargs,
)
response = await self._async_stub.LoadPartitions(
request, timeout=timeout, metadata=_api_level_md(**kwargs)
)
check_status(response)

await self.wait_for_loading_partitions(
collection_name, partition_names, is_refresh=refresh, **kwargs
collection_name=collection_name,
partition_names=partition_names,
is_refresh=request.is_refresh,
timeout=timeout,
**kwargs,
)

@retry_on_rpc_failure()
Expand Down
121 changes: 42 additions & 79 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .asynch import (
CreateIndexFuture,
FlushFuture,
LoadPartitionsFuture,
MutationFuture,
SearchFuture,
)
Expand Down Expand Up @@ -1263,43 +1262,36 @@ def wait_for_creating_index(
def load_collection(
self,
collection_name: str,
replica_number: int = 1,
replica_number: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs,
):
check_pass_param(
collection_name=collection_name, replica_number=replica_number, timeout=timeout
)
# leading _ is misused for keywork escape for `async`
# other params now support prefix _ or not
# params without leading "_" have higher priority
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)
check_pass_param(timeout=timeout)

request = Prepare.load_collection(
"",
collection_name,
replica_number,
refresh,
resource_groups,
load_fields,
skip_load_dynamic_field,
)
request = Prepare.load_collection(collection_name, replica_number, **kwargs)
response = self._stub.LoadCollection(
request, timeout=timeout, metadata=_api_level_md(**kwargs)
request,
timeout=timeout,
metadata=_api_level_md(**kwargs),
)
check_status(response)
_async = kwargs.get("_async", False)
if not _async:
self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh, **kwargs)

if kwargs.get("_async", False):
return

self.wait_for_loading_collection(
collection_name=collection_name,
is_refresh=request.refresh,
timeout=timeout,
**kwargs,
)

@retry_on_rpc_failure()
def load_collection_progress(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
self,
collection_name: str,
timeout: Optional[float] = None,
**kwargs,
):
"""Return loading progress of collection"""
progress = self.get_loading_progress(collection_name, timeout=timeout)
Expand All @@ -1322,7 +1314,10 @@ def can_loop(t: int) -> bool:

while can_loop(time.time()):
progress = self.get_loading_progress(
collection_name, timeout=timeout, is_refresh=is_refresh, **kwargs
collection_name,
is_refresh=is_refresh,
timeout=timeout,
**kwargs,
)
if progress >= 100:
return
Expand All @@ -1345,66 +1340,30 @@ def load_partitions(
self,
collection_name: str,
partition_names: List[str],
replica_number: int = 1,
replica_number: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs,
):
check_pass_param(
collection_name=collection_name,
partition_name_array=partition_names,
replica_number=replica_number,
timeout=timeout,
)
# leading _ is misused for keywork escape for `async`
# other params now support prefix _ or not
# params without leading "_" have higher priority
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)
check_pass_param(timeout=timeout)

request = Prepare.load_partitions(
"",
collection_name,
partition_names,
replica_number,
refresh,
resource_groups,
load_fields,
skip_load_dynamic_field,
collection_name=collection_name,
partition_names=partition_names,
replica_number=replica_number,
)
future = self._stub.LoadPartitions.future(
response = self._stub.LoadPartitions(
request, timeout=timeout, metadata=_api_level_md(**kwargs)
)

if kwargs.get("_async", False):

def _check():
if kwargs.get("sync", True):
self.wait_for_loading_partitions(
collection_name, partition_names, is_refresh=refresh, **kwargs
)

load_partitions_future = LoadPartitionsFuture(future)
load_partitions_future.add_callback(_check)

user_cb = kwargs.get("_callback")
if user_cb:
load_partitions_future.add_callback(user_cb)

return load_partitions_future

response = future.result()
check_status(response)
sync = kwargs.get("sync", True)
if sync:

if kwargs.get("sync", True) or not kwargs.get("_async", False):
self.wait_for_loading_partitions(
collection_name, partition_names, is_refresh=refresh, **kwargs
collection_name=collection_name,
partition_names=partition_names,
is_refresh=request.refresh,
timeout=timeout,
**kwargs,
)
return None
return None

@retry_on_rpc_failure()
def wait_for_loading_partitions(
Expand All @@ -1422,7 +1381,11 @@ def can_loop(t: int) -> bool:

while can_loop(time.time()):
progress = self.get_loading_progress(
collection_name, partition_names, timeout=timeout, is_refresh=is_refresh, **kwargs
collection_name=collection_name,
partition_names=partition_names,
timeout=timeout,
is_refresh=is_refresh,
**kwargs,
)
if progress >= 100:
return
Expand Down
88 changes: 61 additions & 27 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,24 +1191,40 @@ def get_index_state_request(cls, collection_name: str, index_name: str):
@classmethod
def load_collection(
cls,
db_name: str,
collection_name: str,
replica_number: int,
refresh: bool,
resource_groups: List[str],
load_fields: List[str],
skip_load_dynamic_field: bool,
replica_number: Optional[int] = None,
**kwargs,
):
return milvus_types.LoadCollectionRequest(
db_name=db_name,
check_pass_param(collection_name=collection_name)
req = milvus_types.LoadCollectionRequest(
collection_name=collection_name,
replica_number=replica_number,
refresh=refresh,
resource_groups=resource_groups,
load_fields=load_fields,
skip_load_dynamic_field=skip_load_dynamic_field,
)

if replica_number:
check_pass_param(replica_number=replica_number)
req.replica_number = replica_number

# Keep underscore key for backward compatibility
if "refresh" in kwargs or "_refresh" in kwargs:
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
req.refresh = refresh

if "resource_groups" in kwargs or "_resource_groups" in kwargs:
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
req.resource_groups = resource_groups

if "load_fields" in kwargs or "_load_fields" in kwargs:
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
req.load_fields = load_fields

if "skip_load_dynamic_field" in kwargs or "_skip_load_dynamic_field" in kwargs:
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)
req.skip_load_dynamic_field = skip_load_dynamic_field

return req

@classmethod
def release_collection(cls, db_name: str, collection_name: str):
return milvus_types.ReleaseCollectionRequest(
Expand All @@ -1218,26 +1234,44 @@ def release_collection(cls, db_name: str, collection_name: str):
@classmethod
def load_partitions(
cls,
db_name: str,
collection_name: str,
partition_names: List[str],
replica_number: int,
refresh: bool,
resource_groups: List[str],
load_fields: List[str],
skip_load_dynamic_field: bool,
replica_number: Optional[int] = None,
**kwargs,
):
return milvus_types.LoadPartitionsRequest(
db_name=db_name,
check_pass_param(collection_name=collection_name)
req = milvus_types.LoadPartitionsRequest(
collection_name=collection_name,
partition_names=partition_names,
replica_number=replica_number,
refresh=refresh,
resource_groups=resource_groups,
load_fields=load_fields,
skip_load_dynamic_field=skip_load_dynamic_field,
)

if partition_names:
check_pass_param(partition_name_array=partition_names)
req.partition_names.extend(partition_names)

if replica_number:
check_pass_param(replica_number=replica_number)
req.replica_number = replica_number

# Keep underscore key for backward compatibility
if "refresh" in kwargs or "_refresh" in kwargs:
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
req.refresh = refresh

if "resource_groups" in kwargs or "_resource_groups" in kwargs:
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
req.resource_groups = resource_groups

if "load_fields" in kwargs or "_load_fields" in kwargs:
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
req.load_fields = load_fields

if "skip_load_dynamic_field" in kwargs or "_skip_load_dynamic_field" in kwargs:
skip_load_dynamic_field = kwargs.get(
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
)
req.skip_load_dynamic_field = skip_load_dynamic_field
return req

@classmethod
def release_partitions(cls, db_name: str, collection_name: str, partition_names: List[str]):
return milvus_types.ReleasePartitionsRequest(
Expand Down
Loading