Skip to content

Commit 5dec228

Browse files
authored
fix: [2.4] ensure create_index and load_collection are fully completed (#2477)
- Add `wait_for_creating_index`, `wait_for_loading_collection` to ensure `create_index` and `load_collection` are fully completed --------- Signed-off-by: Ruichen Bao <[email protected]>
1 parent a4ead48 commit 5dec228

File tree

2 files changed

+149
-46
lines changed

2 files changed

+149
-46
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 146 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
import base64
33
import copy
44
import socket
5+
import time
56
from pathlib import Path
67
from typing import Callable, Dict, List, Optional, Union
78
from urllib import parse
89

910
import grpc
1011
from grpc._cython import cygrpc
1112

12-
from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder
13+
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
1314
from pymilvus.exceptions import (
15+
AmbiguousIndexName,
1416
DescribeCollectionException,
17+
ExceptionsMessage,
1518
MilvusException,
1619
ParamError,
1720
)
@@ -31,6 +34,7 @@
3134
from .types import (
3235
DataType,
3336
ExtraList,
37+
IndexState,
3438
Status,
3539
get_cost_extra,
3640
)
@@ -63,6 +67,7 @@ def __init__(
6367
self._set_authorization(**kwargs)
6468
self._setup_db_name(kwargs.get("db_name"))
6569
self._setup_grpc_channel(**kwargs)
70+
self._is_channel_ready = False
6671
self.callbacks = []
6772

6873
def register_state_change_callback(self, callback: Callable):
@@ -103,33 +108,10 @@ def __enter__(self):
103108
def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object):
104109
pass
105110

106-
def _wait_for_channel_ready(self, timeout: Union[float] = 10, retry_interval: float = 1):
107-
try:
108-
109-
async def wait_for_async_channel_ready():
110-
await self._async_channel.channel_ready()
111-
112-
loop = asyncio.get_event_loop()
113-
loop.run_until_complete(wait_for_async_channel_ready())
114-
115-
self._setup_identifier_interceptor(self._user, timeout=timeout)
116-
except grpc.FutureTimeoutError as e:
117-
raise MilvusException(
118-
code=Status.CONNECT_FAILED,
119-
message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
120-
) from e
121-
except Exception as e:
122-
raise e from e
123-
124111
def close(self):
125112
self.deregister_state_change_callbacks()
126113
self._async_channel.close()
127114

128-
def reset_db_name(self, db_name: str):
129-
self._setup_db_name(db_name)
130-
self._setup_grpc_channel()
131-
self._setup_identifier_interceptor(self._user)
132-
133115
def _setup_authorization_interceptor(self, user: str, password: str, token: str):
134116
keys = []
135117
values = []
@@ -228,33 +210,51 @@ def _setup_grpc_channel(self, **kwargs):
228210
self._request_id = None
229211
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)
230212

231-
def _setup_identifier_interceptor(self, user: str, timeout: int = 10):
232-
host = socket.gethostname()
233-
self._identifier = self.__internal_register(user, host, timeout=timeout)
234-
_async_identifier_interceptor = async_header_adder_interceptor(
235-
["identifier"], [str(self._identifier)]
236-
)
237-
self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
238-
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
239-
240213
@property
241214
def server_address(self):
242215
return self._address
243216

244217
def get_server_type(self):
245218
return get_server_type(self.server_address.split(":")[0])
246219

220+
async def ensure_channel_ready(self):
221+
try:
222+
if not self._is_channel_ready:
223+
# wait for channel ready
224+
await self._async_channel.channel_ready()
225+
# set identifier interceptor
226+
host = socket.gethostname()
227+
req = Prepare.register_request(self._user, host)
228+
response = await self._async_stub.Connect(request=req)
229+
check_status(response.status)
230+
_async_identifier_interceptor = async_header_adder_interceptor(
231+
["identifier"], [str(response.identifier)]
232+
)
233+
self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
234+
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
235+
236+
self._is_channel_ready = True
237+
except grpc.FutureTimeoutError as e:
238+
raise MilvusException(
239+
code=Status.CONNECT_FAILED,
240+
message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
241+
) from e
242+
except Exception as e:
243+
raise e from e
244+
247245
@retry_on_rpc_failure()
248246
async def create_collection(
249247
self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs
250248
):
249+
await self.ensure_channel_ready()
251250
check_pass_param(collection_name=collection_name, timeout=timeout)
252251
request = Prepare.create_collection_request(collection_name, fields, **kwargs)
253252
response = await self._async_stub.CreateCollection(request, timeout=timeout)
254253
check_status(response)
255254

256255
@retry_on_rpc_failure()
257256
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None):
257+
await self.ensure_channel_ready()
258258
check_pass_param(collection_name=collection_name, timeout=timeout)
259259
request = Prepare.drop_collection_request(collection_name)
260260
response = await self._async_stub.DropCollection(request, timeout=timeout)
@@ -268,6 +268,7 @@ async def load_collection(
268268
timeout: Optional[float] = None,
269269
**kwargs,
270270
):
271+
await self.ensure_channel_ready()
271272
check_pass_param(
272273
collection_name=collection_name, replica_number=replica_number, timeout=timeout
273274
)
@@ -290,10 +291,48 @@ async def load_collection(
290291
response = await self._async_stub.LoadCollection(request, timeout=timeout)
291292
check_status(response)
292293

294+
await self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh)
295+
296+
@retry_on_rpc_failure()
297+
async def wait_for_loading_collection(
298+
self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False
299+
):
300+
start = time.time()
301+
302+
def can_loop(t: int) -> bool:
303+
return True if timeout is None else t <= (start + timeout)
304+
305+
while can_loop(time.time()):
306+
progress = await self.get_loading_progress(
307+
collection_name, timeout=timeout, is_refresh=is_refresh
308+
)
309+
if progress >= 100:
310+
return
311+
await asyncio.sleep(Config.WaitTimeDurationWhenLoad)
312+
raise MilvusException(
313+
message=f"wait for loading collection timeout, collection: {collection_name}"
314+
)
315+
316+
@retry_on_rpc_failure()
317+
async def get_loading_progress(
318+
self,
319+
collection_name: str,
320+
partition_names: Optional[List[str]] = None,
321+
timeout: Optional[float] = None,
322+
is_refresh: bool = False,
323+
):
324+
request = Prepare.get_loading_progress(collection_name, partition_names)
325+
response = await self._async_stub.GetLoadingProgress(request, timeout=timeout)
326+
check_status(response.status)
327+
if is_refresh:
328+
return response.refresh_progress
329+
return response.progress
330+
293331
@retry_on_rpc_failure()
294332
async def describe_collection(
295333
self, collection_name: str, timeout: Optional[float] = None, **kwargs
296334
):
335+
await self.ensure_channel_ready()
297336
check_pass_param(collection_name=collection_name, timeout=timeout)
298337
request = Prepare.describe_collection_request(collection_name)
299338
response = await self._async_stub.DescribeCollection(request, timeout=timeout)
@@ -324,6 +363,7 @@ async def insert_rows(
324363
timeout: Optional[float] = None,
325364
**kwargs,
326365
):
366+
await self.ensure_channel_ready()
327367
request = await self._prepare_row_insert_request(
328368
collection_name, entities, partition_name, schema, timeout, **kwargs
329369
)
@@ -358,6 +398,7 @@ async def _prepare_row_insert_request(
358398
enable_dynamic=enable_dynamic,
359399
)
360400

401+
@retry_on_rpc_failure()
361402
async def delete(
362403
self,
363404
collection_name: str,
@@ -366,6 +407,7 @@ async def delete(
366407
timeout: Optional[float] = None,
367408
**kwargs,
368409
):
410+
await self.ensure_channel_ready()
369411
check_pass_param(collection_name=collection_name, timeout=timeout)
370412
try:
371413
req = Prepare.delete_request(
@@ -420,6 +462,7 @@ async def upsert(
420462
timeout: Optional[float] = None,
421463
**kwargs,
422464
):
465+
await self.ensure_channel_ready()
423466
if not check_invalid_binary_vector(entities):
424467
raise ParamError(message="Invalid binary vector data exists")
425468

@@ -465,6 +508,7 @@ async def upsert_rows(
465508
timeout: Optional[float] = None,
466509
**kwargs,
467510
):
511+
await self.ensure_channel_ready()
468512
if isinstance(entities, dict):
469513
entities = [entities]
470514
request = await self._prepare_row_upsert_request(
@@ -518,6 +562,7 @@ async def search(
518562
timeout: Optional[float] = None,
519563
**kwargs,
520564
):
565+
await self.ensure_channel_ready()
521566
check_pass_param(
522567
limit=limit,
523568
round_decimal=round_decimal,
@@ -555,6 +600,7 @@ async def hybrid_search(
555600
timeout: Optional[float] = None,
556601
**kwargs,
557602
):
603+
await self.ensure_channel_ready()
558604
check_pass_param(
559605
limit=limit,
560606
round_decimal=round_decimal,
@@ -608,7 +654,7 @@ async def create_index(
608654
collection_desc = await self.describe_collection(
609655
collection_name, timeout=timeout, **copy_kwargs
610656
)
611-
657+
await self.ensure_channel_ready()
612658
valid_field = False
613659
for fields in collection_desc["fields"]:
614660
if field_name != fields["name"]:
@@ -633,8 +679,67 @@ async def create_index(
633679
status = await self._async_stub.CreateIndex(index_param, timeout=timeout)
634680
check_status(status)
635681

682+
index_success, fail_reason = await self.wait_for_creating_index(
683+
collection_name=collection_name,
684+
index_name=index_name,
685+
timeout=timeout,
686+
field_name=field_name,
687+
)
688+
689+
if not index_success:
690+
raise MilvusException(message=fail_reason)
691+
636692
return Status(status.code, status.reason)
637693

694+
@retry_on_rpc_failure()
695+
async def wait_for_creating_index(
696+
self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs
697+
):
698+
timestamp = await self.alloc_timestamp()
699+
start = time.time()
700+
while True:
701+
await asyncio.sleep(0.5)
702+
state, fail_reason = await self.get_index_state(
703+
collection_name, index_name, timeout=timeout, timestamp=timestamp, **kwargs
704+
)
705+
if state == IndexState.Finished:
706+
return True, fail_reason
707+
if state == IndexState.Failed:
708+
return False, fail_reason
709+
end = time.time()
710+
if isinstance(timeout, int) and end - start > timeout:
711+
msg = (
712+
f"collection {collection_name} create index {index_name} "
713+
f"timeout in {timeout}s"
714+
)
715+
raise MilvusException(message=msg)
716+
717+
@retry_on_rpc_failure()
718+
async def get_index_state(
719+
self,
720+
collection_name: str,
721+
index_name: str,
722+
timeout: Optional[float] = None,
723+
timestamp: Optional[int] = None,
724+
**kwargs,
725+
):
726+
request = Prepare.describe_index_request(collection_name, index_name, timestamp)
727+
response = await self._async_stub.DescribeIndex(request, timeout=timeout)
728+
status = response.status
729+
check_status(status)
730+
731+
if len(response.index_descriptions) == 1:
732+
index_desc = response.index_descriptions[0]
733+
return index_desc.state, index_desc.index_state_fail_reason
734+
# just for create_index.
735+
field_name = kwargs.pop("field_name", "")
736+
if field_name != "":
737+
for index_desc in response.index_descriptions:
738+
if index_desc.field_name == field_name:
739+
return index_desc.state, index_desc.index_state_fail_reason
740+
741+
raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName)
742+
638743
@retry_on_rpc_failure()
639744
async def get(
640745
self,
@@ -645,6 +750,7 @@ async def get(
645750
timeout: Optional[float] = None,
646751
):
647752
# TODO: some check
753+
await self.ensure_channel_ready()
648754
request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names)
649755
return await self._async_stub.Retrieve.get(request, timeout=timeout)
650756

@@ -658,6 +764,7 @@ async def query(
658764
timeout: Optional[float] = None,
659765
**kwargs,
660766
):
767+
await self.ensure_channel_ready()
661768
if output_fields is not None and not isinstance(output_fields, (list,)):
662769
raise ParamError(message="Invalid query format. 'output_fields' must be a list")
663770
request = Prepare.query_request(
@@ -690,15 +797,9 @@ async def query(
690797
return ExtraList(results, extra=extra_dict)
691798

692799
@retry_on_rpc_failure()
693-
@upgrade_reminder
694-
def __internal_register(self, user: str, host: str, **kwargs) -> int:
695-
req = Prepare.register_request(user, host)
696-
697-
async def wait_for_connect_response():
698-
return await self._async_stub.Connect(request=req)
699-
700-
loop = asyncio.get_event_loop()
701-
response = loop.run_until_complete(wait_for_connect_response())
702-
800+
@ignore_unimplemented(0)
801+
async def alloc_timestamp(self, timeout: Optional[float] = None) -> int:
802+
request = milvus_types.AllocTimestampRequest()
803+
response = await self._async_stub.AllocTimestamp(request, timeout=timeout)
703804
check_status(response.status)
704-
return response.identifier
805+
return response.timestamp

pymilvus/orm/connections.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ def connect_milvus(**kwargs):
400400
t = kwargs.get("timeout")
401401
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
402402

403-
gh._wait_for_channel_ready(timeout=timeout)
403+
if not _async:
404+
gh._wait_for_channel_ready(timeout=timeout)
405+
404406
if kwargs.get("keep_alive", False):
405407
gh.register_state_change_callback(
406408
ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle

0 commit comments

Comments
 (0)