Skip to content

Commit 637e548

Browse files
authored
fix: [cp25]remove default replica_number in load_collection (#2782) (#2786)
The default value should be decided in the server side See also: #2781, milvus-io/milvus#41673, #2782 Signed-off-by: yangxuan <[email protected]>
1 parent 898eac3 commit 637e548

File tree

5 files changed

+129
-153
lines changed

5 files changed

+129
-153
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -266,37 +266,24 @@ async def drop_collection(
266266
async def load_collection(
267267
self,
268268
collection_name: str,
269-
replica_number: int = 1,
269+
replica_number: Optional[int] = None,
270270
timeout: Optional[float] = None,
271271
**kwargs,
272272
):
273273
await self.ensure_channel_ready()
274-
check_pass_param(
275-
collection_name=collection_name, replica_number=replica_number, timeout=timeout
276-
)
277-
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
278-
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
279-
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
280-
skip_load_dynamic_field = kwargs.get(
281-
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
282-
)
283274

284-
request = Prepare.load_collection(
285-
"",
286-
collection_name,
287-
replica_number,
288-
refresh,
289-
resource_groups,
290-
load_fields,
291-
skip_load_dynamic_field,
292-
)
275+
check_pass_param(timeout=timeout)
276+
request = Prepare.load_collection(collection_name, replica_number, **kwargs)
293277
response = await self._async_stub.LoadCollection(
294278
request, timeout=timeout, metadata=_api_level_md(**kwargs)
295279
)
296280
check_status(response)
297281

298282
await self.wait_for_loading_collection(
299-
collection_name, timeout, is_refresh=refresh, **kwargs
283+
collection_name=collection_name,
284+
is_refresh=request.refresh,
285+
timeout=timeout,
286+
**kwargs,
300287
)
301288

302289
@retry_on_rpc_failure()
@@ -314,7 +301,10 @@ def can_loop(t: int) -> bool:
314301

315302
while can_loop(time.time()):
316303
progress = await self.get_loading_progress(
317-
collection_name, timeout=timeout, is_refresh=is_refresh, **kwargs
304+
collection_name=collection_name,
305+
is_refresh=is_refresh,
306+
timeout=timeout,
307+
**kwargs,
318308
)
319309
if progress >= 100:
320310
return
@@ -825,41 +815,30 @@ async def load_partitions(
825815
self,
826816
collection_name: str,
827817
partition_names: List[str],
828-
replica_number: int = 1,
818+
replica_number: Optional[int] = None,
829819
timeout: Optional[float] = None,
830820
**kwargs,
831821
):
832822
await self.ensure_channel_ready()
833-
check_pass_param(
834-
collection_name=collection_name,
835-
partition_name_array=partition_names,
836-
replica_number=replica_number,
837-
timeout=timeout,
838-
)
839-
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
840-
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
841-
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
842-
skip_load_dynamic_field = kwargs.get(
843-
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
844-
)
823+
check_pass_param(timeout=timeout)
845824

846825
request = Prepare.load_partitions(
847-
"",
848-
collection_name,
849-
partition_names,
850-
replica_number,
851-
refresh,
852-
resource_groups,
853-
load_fields,
854-
skip_load_dynamic_field,
826+
collection_name=collection_name,
827+
partition_names=partition_names,
828+
replica_number=replica_number,
829+
**kwargs,
855830
)
856831
response = await self._async_stub.LoadPartitions(
857832
request, timeout=timeout, metadata=_api_level_md(**kwargs)
858833
)
859834
check_status(response)
860835

861836
await self.wait_for_loading_partitions(
862-
collection_name, partition_names, is_refresh=refresh, **kwargs
837+
collection_name=collection_name,
838+
partition_names=partition_names,
839+
is_refresh=request.is_refresh,
840+
timeout=timeout,
841+
**kwargs,
863842
)
864843

865844
@retry_on_rpc_failure()

pymilvus/client/grpc_handler.py

Lines changed: 42 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from .asynch import (
3333
CreateIndexFuture,
3434
FlushFuture,
35-
LoadPartitionsFuture,
3635
MutationFuture,
3736
SearchFuture,
3837
)
@@ -1263,43 +1262,36 @@ def wait_for_creating_index(
12631262
def load_collection(
12641263
self,
12651264
collection_name: str,
1266-
replica_number: int = 1,
1265+
replica_number: Optional[int] = None,
12671266
timeout: Optional[float] = None,
12681267
**kwargs,
12691268
):
1270-
check_pass_param(
1271-
collection_name=collection_name, replica_number=replica_number, timeout=timeout
1272-
)
1273-
# leading _ is misused for keywork escape for `async`
1274-
# other params now support prefix _ or not
1275-
# params without leading "_" have higher priority
1276-
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
1277-
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
1278-
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
1279-
skip_load_dynamic_field = kwargs.get(
1280-
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
1281-
)
1269+
check_pass_param(timeout=timeout)
12821270

1283-
request = Prepare.load_collection(
1284-
"",
1285-
collection_name,
1286-
replica_number,
1287-
refresh,
1288-
resource_groups,
1289-
load_fields,
1290-
skip_load_dynamic_field,
1291-
)
1271+
request = Prepare.load_collection(collection_name, replica_number, **kwargs)
12921272
response = self._stub.LoadCollection(
1293-
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1273+
request,
1274+
timeout=timeout,
1275+
metadata=_api_level_md(**kwargs),
12941276
)
12951277
check_status(response)
1296-
_async = kwargs.get("_async", False)
1297-
if not _async:
1298-
self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh, **kwargs)
1278+
1279+
if kwargs.get("_async", False):
1280+
return
1281+
1282+
self.wait_for_loading_collection(
1283+
collection_name=collection_name,
1284+
is_refresh=request.refresh,
1285+
timeout=timeout,
1286+
**kwargs,
1287+
)
12991288

13001289
@retry_on_rpc_failure()
13011290
def load_collection_progress(
1302-
self, collection_name: str, timeout: Optional[float] = None, **kwargs
1291+
self,
1292+
collection_name: str,
1293+
timeout: Optional[float] = None,
1294+
**kwargs,
13031295
):
13041296
"""Return loading progress of collection"""
13051297
progress = self.get_loading_progress(collection_name, timeout=timeout)
@@ -1322,7 +1314,10 @@ def can_loop(t: int) -> bool:
13221314

13231315
while can_loop(time.time()):
13241316
progress = self.get_loading_progress(
1325-
collection_name, timeout=timeout, is_refresh=is_refresh, **kwargs
1317+
collection_name,
1318+
is_refresh=is_refresh,
1319+
timeout=timeout,
1320+
**kwargs,
13261321
)
13271322
if progress >= 100:
13281323
return
@@ -1345,66 +1340,30 @@ def load_partitions(
13451340
self,
13461341
collection_name: str,
13471342
partition_names: List[str],
1348-
replica_number: int = 1,
1343+
replica_number: Optional[int] = None,
13491344
timeout: Optional[float] = None,
13501345
**kwargs,
13511346
):
1352-
check_pass_param(
1353-
collection_name=collection_name,
1354-
partition_name_array=partition_names,
1355-
replica_number=replica_number,
1356-
timeout=timeout,
1357-
)
1358-
# leading _ is misused for keywork escape for `async`
1359-
# other params now support prefix _ or not
1360-
# params without leading "_" have higher priority
1361-
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
1362-
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
1363-
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
1364-
skip_load_dynamic_field = kwargs.get(
1365-
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
1366-
)
1347+
check_pass_param(timeout=timeout)
13671348

13681349
request = Prepare.load_partitions(
1369-
"",
1370-
collection_name,
1371-
partition_names,
1372-
replica_number,
1373-
refresh,
1374-
resource_groups,
1375-
load_fields,
1376-
skip_load_dynamic_field,
1350+
collection_name=collection_name,
1351+
partition_names=partition_names,
1352+
replica_number=replica_number,
13771353
)
1378-
future = self._stub.LoadPartitions.future(
1354+
response = self._stub.LoadPartitions(
13791355
request, timeout=timeout, metadata=_api_level_md(**kwargs)
13801356
)
1381-
1382-
if kwargs.get("_async", False):
1383-
1384-
def _check():
1385-
if kwargs.get("sync", True):
1386-
self.wait_for_loading_partitions(
1387-
collection_name, partition_names, is_refresh=refresh, **kwargs
1388-
)
1389-
1390-
load_partitions_future = LoadPartitionsFuture(future)
1391-
load_partitions_future.add_callback(_check)
1392-
1393-
user_cb = kwargs.get("_callback")
1394-
if user_cb:
1395-
load_partitions_future.add_callback(user_cb)
1396-
1397-
return load_partitions_future
1398-
1399-
response = future.result()
14001357
check_status(response)
1401-
sync = kwargs.get("sync", True)
1402-
if sync:
1358+
1359+
if kwargs.get("sync", True) or not kwargs.get("_async", False):
14031360
self.wait_for_loading_partitions(
1404-
collection_name, partition_names, is_refresh=refresh, **kwargs
1361+
collection_name=collection_name,
1362+
partition_names=partition_names,
1363+
is_refresh=request.refresh,
1364+
timeout=timeout,
1365+
**kwargs,
14051366
)
1406-
return None
1407-
return None
14081367

14091368
@retry_on_rpc_failure()
14101369
def wait_for_loading_partitions(
@@ -1422,7 +1381,11 @@ def can_loop(t: int) -> bool:
14221381

14231382
while can_loop(time.time()):
14241383
progress = self.get_loading_progress(
1425-
collection_name, partition_names, timeout=timeout, is_refresh=is_refresh, **kwargs
1384+
collection_name=collection_name,
1385+
partition_names=partition_names,
1386+
timeout=timeout,
1387+
is_refresh=is_refresh,
1388+
**kwargs,
14261389
)
14271390
if progress >= 100:
14281391
return

pymilvus/client/prepare.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,24 +1191,40 @@ def get_index_state_request(cls, collection_name: str, index_name: str):
11911191
@classmethod
11921192
def load_collection(
11931193
cls,
1194-
db_name: str,
11951194
collection_name: str,
1196-
replica_number: int,
1197-
refresh: bool,
1198-
resource_groups: List[str],
1199-
load_fields: List[str],
1200-
skip_load_dynamic_field: bool,
1195+
replica_number: Optional[int] = None,
1196+
**kwargs,
12011197
):
1202-
return milvus_types.LoadCollectionRequest(
1203-
db_name=db_name,
1198+
check_pass_param(collection_name=collection_name)
1199+
req = milvus_types.LoadCollectionRequest(
12041200
collection_name=collection_name,
1205-
replica_number=replica_number,
1206-
refresh=refresh,
1207-
resource_groups=resource_groups,
1208-
load_fields=load_fields,
1209-
skip_load_dynamic_field=skip_load_dynamic_field,
12101201
)
12111202

1203+
if replica_number:
1204+
check_pass_param(replica_number=replica_number)
1205+
req.replica_number = replica_number
1206+
1207+
# Keep underscore key for backward compatibility
1208+
if "refresh" in kwargs or "_refresh" in kwargs:
1209+
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
1210+
req.refresh = refresh
1211+
1212+
if "resource_groups" in kwargs or "_resource_groups" in kwargs:
1213+
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
1214+
req.resource_groups = resource_groups
1215+
1216+
if "load_fields" in kwargs or "_load_fields" in kwargs:
1217+
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
1218+
req.load_fields = load_fields
1219+
1220+
if "skip_load_dynamic_field" in kwargs or "_skip_load_dynamic_field" in kwargs:
1221+
skip_load_dynamic_field = kwargs.get(
1222+
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
1223+
)
1224+
req.skip_load_dynamic_field = skip_load_dynamic_field
1225+
1226+
return req
1227+
12121228
@classmethod
12131229
def release_collection(cls, db_name: str, collection_name: str):
12141230
return milvus_types.ReleaseCollectionRequest(
@@ -1218,26 +1234,44 @@ def release_collection(cls, db_name: str, collection_name: str):
12181234
@classmethod
12191235
def load_partitions(
12201236
cls,
1221-
db_name: str,
12221237
collection_name: str,
12231238
partition_names: List[str],
1224-
replica_number: int,
1225-
refresh: bool,
1226-
resource_groups: List[str],
1227-
load_fields: List[str],
1228-
skip_load_dynamic_field: bool,
1239+
replica_number: Optional[int] = None,
1240+
**kwargs,
12291241
):
1230-
return milvus_types.LoadPartitionsRequest(
1231-
db_name=db_name,
1242+
check_pass_param(collection_name=collection_name)
1243+
req = milvus_types.LoadPartitionsRequest(
12321244
collection_name=collection_name,
1233-
partition_names=partition_names,
1234-
replica_number=replica_number,
1235-
refresh=refresh,
1236-
resource_groups=resource_groups,
1237-
load_fields=load_fields,
1238-
skip_load_dynamic_field=skip_load_dynamic_field,
12391245
)
12401246

1247+
if partition_names:
1248+
check_pass_param(partition_name_array=partition_names)
1249+
req.partition_names.extend(partition_names)
1250+
1251+
if replica_number:
1252+
check_pass_param(replica_number=replica_number)
1253+
req.replica_number = replica_number
1254+
1255+
# Keep underscore key for backward compatibility
1256+
if "refresh" in kwargs or "_refresh" in kwargs:
1257+
refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
1258+
req.refresh = refresh
1259+
1260+
if "resource_groups" in kwargs or "_resource_groups" in kwargs:
1261+
resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
1262+
req.resource_groups = resource_groups
1263+
1264+
if "load_fields" in kwargs or "_load_fields" in kwargs:
1265+
load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
1266+
req.load_fields = load_fields
1267+
1268+
if "skip_load_dynamic_field" in kwargs or "_skip_load_dynamic_field" in kwargs:
1269+
skip_load_dynamic_field = kwargs.get(
1270+
"skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
1271+
)
1272+
req.skip_load_dynamic_field = skip_load_dynamic_field
1273+
return req
1274+
12411275
@classmethod
12421276
def release_partitions(cls, db_name: str, collection_name: str, partition_names: List[str]):
12431277
return milvus_types.ReleasePartitionsRequest(

0 commit comments

Comments
 (0)