2
2
import base64
3
3
import copy
4
4
import socket
5
+ import time
5
6
from pathlib import Path
6
7
from typing import Callable , Dict , List , Optional , Union
7
8
from urllib import parse
8
9
9
10
import grpc
10
11
from grpc ._cython import cygrpc
11
12
12
- from pymilvus .decorators import retry_on_rpc_failure , upgrade_reminder
13
+ from pymilvus .decorators import ignore_unimplemented , retry_on_rpc_failure
13
14
from pymilvus .exceptions import (
15
+ AmbiguousIndexName ,
14
16
DescribeCollectionException ,
17
+ ExceptionsMessage ,
15
18
MilvusException ,
16
19
ParamError ,
17
20
)
31
34
from .types import (
32
35
DataType ,
33
36
ExtraList ,
37
+ IndexState ,
34
38
Status ,
35
39
get_cost_extra ,
36
40
)
@@ -63,6 +67,7 @@ def __init__(
63
67
self ._set_authorization (** kwargs )
64
68
self ._setup_db_name (kwargs .get ("db_name" ))
65
69
self ._setup_grpc_channel (** kwargs )
70
+ self ._is_channel_ready = False
66
71
self .callbacks = []
67
72
68
73
def register_state_change_callback (self , callback : Callable ):
@@ -103,33 +108,10 @@ def __enter__(self):
103
108
def __exit__ (self : object , exc_type : object , exc_val : object , exc_tb : object ):
104
109
pass
105
110
106
- def _wait_for_channel_ready (self , timeout : Union [float ] = 10 , retry_interval : float = 1 ):
107
- try :
108
-
109
- async def wait_for_async_channel_ready ():
110
- await self ._async_channel .channel_ready ()
111
-
112
- loop = asyncio .get_event_loop ()
113
- loop .run_until_complete (wait_for_async_channel_ready ())
114
-
115
- self ._setup_identifier_interceptor (self ._user , timeout = timeout )
116
- except grpc .FutureTimeoutError as e :
117
- raise MilvusException (
118
- code = Status .CONNECT_FAILED ,
119
- message = f"Fail connecting to server on { self ._address } , illegal connection params or server unavailable" ,
120
- ) from e
121
- except Exception as e :
122
- raise e from e
123
-
124
111
def close (self ):
125
112
self .deregister_state_change_callbacks ()
126
113
self ._async_channel .close ()
127
114
128
- def reset_db_name (self , db_name : str ):
129
- self ._setup_db_name (db_name )
130
- self ._setup_grpc_channel ()
131
- self ._setup_identifier_interceptor (self ._user )
132
-
133
115
def _setup_authorization_interceptor (self , user : str , password : str , token : str ):
134
116
keys = []
135
117
values = []
@@ -228,33 +210,51 @@ def _setup_grpc_channel(self, **kwargs):
228
210
self ._request_id = None
229
211
self ._async_stub = milvus_pb2_grpc .MilvusServiceStub (self ._final_channel )
230
212
231
- def _setup_identifier_interceptor (self , user : str , timeout : int = 10 ):
232
- host = socket .gethostname ()
233
- self ._identifier = self .__internal_register (user , host , timeout = timeout )
234
- _async_identifier_interceptor = async_header_adder_interceptor (
235
- ["identifier" ], [str (self ._identifier )]
236
- )
237
- self ._async_channel ._unary_unary_interceptors .append (_async_identifier_interceptor )
238
- self ._async_stub = milvus_pb2_grpc .MilvusServiceStub (self ._async_channel )
239
-
240
213
@property
241
214
def server_address (self ):
242
215
return self ._address
243
216
244
217
def get_server_type (self ):
245
218
return get_server_type (self .server_address .split (":" )[0 ])
246
219
220
+ async def ensure_channel_ready (self ):
221
+ try :
222
+ if not self ._is_channel_ready :
223
+ # wait for channel ready
224
+ await self ._async_channel .channel_ready ()
225
+ # set identifier interceptor
226
+ host = socket .gethostname ()
227
+ req = Prepare .register_request (self ._user , host )
228
+ response = await self ._async_stub .Connect (request = req )
229
+ check_status (response .status )
230
+ _async_identifier_interceptor = async_header_adder_interceptor (
231
+ ["identifier" ], [str (response .identifier )]
232
+ )
233
+ self ._async_channel ._unary_unary_interceptors .append (_async_identifier_interceptor )
234
+ self ._async_stub = milvus_pb2_grpc .MilvusServiceStub (self ._async_channel )
235
+
236
+ self ._is_channel_ready = True
237
+ except grpc .FutureTimeoutError as e :
238
+ raise MilvusException (
239
+ code = Status .CONNECT_FAILED ,
240
+ message = f"Fail connecting to server on { self ._address } , illegal connection params or server unavailable" ,
241
+ ) from e
242
+ except Exception as e :
243
+ raise e from e
244
+
247
245
@retry_on_rpc_failure ()
248
246
async def create_collection (
249
247
self , collection_name : str , fields : List , timeout : Optional [float ] = None , ** kwargs
250
248
):
249
+ await self .ensure_channel_ready ()
251
250
check_pass_param (collection_name = collection_name , timeout = timeout )
252
251
request = Prepare .create_collection_request (collection_name , fields , ** kwargs )
253
252
response = await self ._async_stub .CreateCollection (request , timeout = timeout )
254
253
check_status (response )
255
254
256
255
@retry_on_rpc_failure ()
257
256
async def drop_collection (self , collection_name : str , timeout : Optional [float ] = None ):
257
+ await self .ensure_channel_ready ()
258
258
check_pass_param (collection_name = collection_name , timeout = timeout )
259
259
request = Prepare .drop_collection_request (collection_name )
260
260
response = await self ._async_stub .DropCollection (request , timeout = timeout )
@@ -268,6 +268,7 @@ async def load_collection(
268
268
timeout : Optional [float ] = None ,
269
269
** kwargs ,
270
270
):
271
+ await self .ensure_channel_ready ()
271
272
check_pass_param (
272
273
collection_name = collection_name , replica_number = replica_number , timeout = timeout
273
274
)
@@ -290,10 +291,48 @@ async def load_collection(
290
291
response = await self ._async_stub .LoadCollection (request , timeout = timeout )
291
292
check_status (response )
292
293
294
+ await self .wait_for_loading_collection (collection_name , timeout , is_refresh = refresh )
295
+
296
+ @retry_on_rpc_failure ()
297
+ async def wait_for_loading_collection (
298
+ self , collection_name : str , timeout : Optional [float ] = None , is_refresh : bool = False
299
+ ):
300
+ start = time .time ()
301
+
302
+ def can_loop (t : int ) -> bool :
303
+ return True if timeout is None else t <= (start + timeout )
304
+
305
+ while can_loop (time .time ()):
306
+ progress = await self .get_loading_progress (
307
+ collection_name , timeout = timeout , is_refresh = is_refresh
308
+ )
309
+ if progress >= 100 :
310
+ return
311
+ await asyncio .sleep (Config .WaitTimeDurationWhenLoad )
312
+ raise MilvusException (
313
+ message = f"wait for loading collection timeout, collection: { collection_name } "
314
+ )
315
+
316
+ @retry_on_rpc_failure ()
317
+ async def get_loading_progress (
318
+ self ,
319
+ collection_name : str ,
320
+ partition_names : Optional [List [str ]] = None ,
321
+ timeout : Optional [float ] = None ,
322
+ is_refresh : bool = False ,
323
+ ):
324
+ request = Prepare .get_loading_progress (collection_name , partition_names )
325
+ response = await self ._async_stub .GetLoadingProgress (request , timeout = timeout )
326
+ check_status (response .status )
327
+ if is_refresh :
328
+ return response .refresh_progress
329
+ return response .progress
330
+
293
331
@retry_on_rpc_failure ()
294
332
async def describe_collection (
295
333
self , collection_name : str , timeout : Optional [float ] = None , ** kwargs
296
334
):
335
+ await self .ensure_channel_ready ()
297
336
check_pass_param (collection_name = collection_name , timeout = timeout )
298
337
request = Prepare .describe_collection_request (collection_name )
299
338
response = await self ._async_stub .DescribeCollection (request , timeout = timeout )
@@ -324,6 +363,7 @@ async def insert_rows(
324
363
timeout : Optional [float ] = None ,
325
364
** kwargs ,
326
365
):
366
+ await self .ensure_channel_ready ()
327
367
request = await self ._prepare_row_insert_request (
328
368
collection_name , entities , partition_name , schema , timeout , ** kwargs
329
369
)
@@ -358,6 +398,7 @@ async def _prepare_row_insert_request(
358
398
enable_dynamic = enable_dynamic ,
359
399
)
360
400
401
+ @retry_on_rpc_failure ()
361
402
async def delete (
362
403
self ,
363
404
collection_name : str ,
@@ -366,6 +407,7 @@ async def delete(
366
407
timeout : Optional [float ] = None ,
367
408
** kwargs ,
368
409
):
410
+ await self .ensure_channel_ready ()
369
411
check_pass_param (collection_name = collection_name , timeout = timeout )
370
412
try :
371
413
req = Prepare .delete_request (
@@ -420,6 +462,7 @@ async def upsert(
420
462
timeout : Optional [float ] = None ,
421
463
** kwargs ,
422
464
):
465
+ await self .ensure_channel_ready ()
423
466
if not check_invalid_binary_vector (entities ):
424
467
raise ParamError (message = "Invalid binary vector data exists" )
425
468
@@ -465,6 +508,7 @@ async def upsert_rows(
465
508
timeout : Optional [float ] = None ,
466
509
** kwargs ,
467
510
):
511
+ await self .ensure_channel_ready ()
468
512
if isinstance (entities , dict ):
469
513
entities = [entities ]
470
514
request = await self ._prepare_row_upsert_request (
@@ -518,6 +562,7 @@ async def search(
518
562
timeout : Optional [float ] = None ,
519
563
** kwargs ,
520
564
):
565
+ await self .ensure_channel_ready ()
521
566
check_pass_param (
522
567
limit = limit ,
523
568
round_decimal = round_decimal ,
@@ -555,6 +600,7 @@ async def hybrid_search(
555
600
timeout : Optional [float ] = None ,
556
601
** kwargs ,
557
602
):
603
+ await self .ensure_channel_ready ()
558
604
check_pass_param (
559
605
limit = limit ,
560
606
round_decimal = round_decimal ,
@@ -608,7 +654,7 @@ async def create_index(
608
654
collection_desc = await self .describe_collection (
609
655
collection_name , timeout = timeout , ** copy_kwargs
610
656
)
611
-
657
+ await self . ensure_channel_ready ()
612
658
valid_field = False
613
659
for fields in collection_desc ["fields" ]:
614
660
if field_name != fields ["name" ]:
@@ -633,8 +679,67 @@ async def create_index(
633
679
status = await self ._async_stub .CreateIndex (index_param , timeout = timeout )
634
680
check_status (status )
635
681
682
+ index_success , fail_reason = await self .wait_for_creating_index (
683
+ collection_name = collection_name ,
684
+ index_name = index_name ,
685
+ timeout = timeout ,
686
+ field_name = field_name ,
687
+ )
688
+
689
+ if not index_success :
690
+ raise MilvusException (message = fail_reason )
691
+
636
692
return Status (status .code , status .reason )
637
693
694
+ @retry_on_rpc_failure ()
695
+ async def wait_for_creating_index (
696
+ self , collection_name : str , index_name : str , timeout : Optional [float ] = None , ** kwargs
697
+ ):
698
+ timestamp = await self .alloc_timestamp ()
699
+ start = time .time ()
700
+ while True :
701
+ await asyncio .sleep (0.5 )
702
+ state , fail_reason = await self .get_index_state (
703
+ collection_name , index_name , timeout = timeout , timestamp = timestamp , ** kwargs
704
+ )
705
+ if state == IndexState .Finished :
706
+ return True , fail_reason
707
+ if state == IndexState .Failed :
708
+ return False , fail_reason
709
+ end = time .time ()
710
+ if isinstance (timeout , int ) and end - start > timeout :
711
+ msg = (
712
+ f"collection { collection_name } create index { index_name } "
713
+ f"timeout in { timeout } s"
714
+ )
715
+ raise MilvusException (message = msg )
716
+
717
+ @retry_on_rpc_failure ()
718
+ async def get_index_state (
719
+ self ,
720
+ collection_name : str ,
721
+ index_name : str ,
722
+ timeout : Optional [float ] = None ,
723
+ timestamp : Optional [int ] = None ,
724
+ ** kwargs ,
725
+ ):
726
+ request = Prepare .describe_index_request (collection_name , index_name , timestamp )
727
+ response = await self ._async_stub .DescribeIndex (request , timeout = timeout )
728
+ status = response .status
729
+ check_status (status )
730
+
731
+ if len (response .index_descriptions ) == 1 :
732
+ index_desc = response .index_descriptions [0 ]
733
+ return index_desc .state , index_desc .index_state_fail_reason
734
+ # just for create_index.
735
+ field_name = kwargs .pop ("field_name" , "" )
736
+ if field_name != "" :
737
+ for index_desc in response .index_descriptions :
738
+ if index_desc .field_name == field_name :
739
+ return index_desc .state , index_desc .index_state_fail_reason
740
+
741
+ raise AmbiguousIndexName (message = ExceptionsMessage .AmbiguousIndexName )
742
+
638
743
@retry_on_rpc_failure ()
639
744
async def get (
640
745
self ,
@@ -645,6 +750,7 @@ async def get(
645
750
timeout : Optional [float ] = None ,
646
751
):
647
752
# TODO: some check
753
+ await self .ensure_channel_ready ()
648
754
request = Prepare .retrieve_request (collection_name , ids , output_fields , partition_names )
649
755
return await self ._async_stub .Retrieve .get (request , timeout = timeout )
650
756
@@ -658,6 +764,7 @@ async def query(
658
764
timeout : Optional [float ] = None ,
659
765
** kwargs ,
660
766
):
767
+ await self .ensure_channel_ready ()
661
768
if output_fields is not None and not isinstance (output_fields , (list ,)):
662
769
raise ParamError (message = "Invalid query format. 'output_fields' must be a list" )
663
770
request = Prepare .query_request (
@@ -690,15 +797,9 @@ async def query(
690
797
return ExtraList (results , extra = extra_dict )
691
798
692
799
@retry_on_rpc_failure ()
693
- @upgrade_reminder
694
- def __internal_register (self , user : str , host : str , ** kwargs ) -> int :
695
- req = Prepare .register_request (user , host )
696
-
697
- async def wait_for_connect_response ():
698
- return await self ._async_stub .Connect (request = req )
699
-
700
- loop = asyncio .get_event_loop ()
701
- response = loop .run_until_complete (wait_for_connect_response ())
702
-
800
+ @ignore_unimplemented (0 )
801
+ async def alloc_timestamp (self , timeout : Optional [float ] = None ) -> int :
802
+ request = milvus_types .AllocTimestampRequest ()
803
+ response = await self ._async_stub .AllocTimestamp (request , timeout = timeout )
703
804
check_status (response .status )
704
- return response .identifier
805
+ return response .timestamp
0 commit comments