Skip to content

Commit b39aeef

Browse files
authored
fix: Raise ParamError if index params wrong (#2718)
See also: milvus-io/milvus#40564 Signed-off-by: yangxuan <[email protected]> --------- Signed-off-by: yangxuan <[email protected]>
1 parent f0e1072 commit b39aeef

File tree

7 files changed

+156
-55
lines changed

7 files changed

+156
-55
lines changed

examples/simple_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,4 @@ async def other_async_task(collection_name):
122122

123123
results = loop.run_until_complete(other_async_task(collection_name))
124124
for r in results:
125-
print(r)
125+
print(r)

pymilvus/milvus_client/async_milvus_client.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from pymilvus.orm.connections import connections
2222
from pymilvus.orm.types import DataType
2323

24-
from .index import IndexParams
24+
from .check import validate_param
25+
from .index import IndexParam, IndexParams
2526

2627
logger = logging.getLogger(__name__)
2728
logger.setLevel(logging.DEBUG)
@@ -124,7 +125,7 @@ async def _fast_create_collection(
124125
raise ex from ex
125126

126127
index_params = IndexParams()
127-
index_params.add_index(vector_field_name, "", "", metric_type=metric_type)
128+
index_params.add_index(vector_field_name, index_type="AUTOINDEX", metric_type=metric_type)
128129
await self.create_index(collection_name, index_params, timeout=timeout)
129130
await self.load_collection(collection_name, timeout=timeout)
130131

@@ -186,24 +187,29 @@ async def create_index(
186187
timeout: Optional[float] = None,
187188
**kwargs,
188189
):
190+
validate_param("collection_name", collection_name, str)
191+
validate_param("index_params", index_params, IndexParams)
192+
if len(index_params) == 0:
193+
raise ParamError(message="IndexParams is empty, no index can be created")
194+
189195
for index_param in index_params:
190196
await self._create_index(collection_name, index_param, timeout=timeout, **kwargs)
191197

192198
async def _create_index(
193-
self, collection_name: str, index_param: Dict, timeout: Optional[float] = None, **kwargs
199+
self,
200+
collection_name: str,
201+
index_param: IndexParam,
202+
timeout: Optional[float] = None,
203+
**kwargs,
194204
):
195205
conn = self._get_connection()
196206
try:
197-
params = index_param.pop("params", {})
198-
field_name = index_param.pop("field_name", "")
199-
index_name = index_param.pop("index_name", "")
200-
params.update(index_param)
201207
await conn.create_index(
202208
collection_name,
203-
field_name,
204-
params,
209+
index_param.field_name,
210+
index_param.get_index_configs(),
205211
timeout=timeout,
206-
index_name=index_name,
212+
index_name=index_param.index_name,
207213
**kwargs,
208214
)
209215
logger.debug("Successfully created an index on collection: %s", collection_name)
@@ -571,6 +577,13 @@ def create_schema(cls, **kwargs):
571577
kwargs["check_fields"] = False # do not check fields for now
572578
return CollectionSchema([], **kwargs)
573579

580+
@classmethod
581+
def prepare_index_params(cls, field_name: str = "", **kwargs) -> IndexParams:
582+
index_params = IndexParams()
583+
if field_name and validate_param("field_name", field_name, str):
584+
index_params.add_index(field_name, **kwargs)
585+
return index_params
586+
574587
async def close(self):
575588
await connections.async_disconnect(self._using)
576589

pymilvus/milvus_client/check.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
1-
from typing import Any
1+
from typing import Any, Dict, Tuple, Union
22

3+
from pymilvus.exceptions import ParamError
4+
5+
6+
def validate_params(params: Dict[str, Any], expected_type: Union[type, Tuple[type, ...]]):
7+
validate_param("params", params, Dict)
8+
for param_name, param in params.items():
9+
validate_param(param_name, param, expected_type)
10+
11+
12+
def validate_param(param_name: str, param: Any, expected_type: Union[type, Tuple[type, ...]]):
13+
if param is None:
14+
msg = f"missing required argument: [{param_name}]"
15+
raise ParamError(message=msg)
316

4-
def check_param_type(param_name: str, param: Any, expected_type: Any, ignore_none: bool = True):
5-
if ignore_none and param is None:
6-
return
717
if not isinstance(param, expected_type):
8-
msg = f"wrong type of arugment '{param_name}', "
9-
msg += f"expected '{expected_type.__name__}', "
10-
msg += f"got '{type(param).__name__}'"
11-
raise TypeError(msg)
18+
msg = (
19+
f"wrong type of argument [{param_name}], "
20+
f"expected type: [{expected_type.__name__}], "
21+
f"got type: [{type(param).__name__}]"
22+
)
23+
raise ParamError(message=msg)
24+
25+
26+
def validate_noneable_param(
27+
param_name: str, param: Any, expected_type: Union[type, Tuple[type, ...]]
28+
):
29+
if param is None:
30+
return
31+
32+
validate_param(param_name, param, expected_type)

pymilvus/milvus_client/index.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
1+
from typing import Dict
2+
3+
14
class IndexParam:
25
def __init__(self, field_name: str, index_type: str, index_name: str, **kwargs):
6+
"""
7+
Examples:
8+
9+
>>> IndexParam(
10+
>>> field_name="embeddings",
11+
>>> index_type="HNSW",
12+
>>> index_name="hnsw_index",
13+
>>> metric_type="COSINE",
14+
>>> M=64,
15+
>>> efConstruction=100,
16+
>>> )
17+
"""
318
self._field_name = field_name
419
self._index_type = index_type
520
self._index_name = index_name
6-
self._kwargs = kwargs
21+
22+
# index configs are unique to each index,
23+
# if params={} is passed in, it will be flattened and merged
24+
# with other configs.
25+
self._configs = {}
26+
self._configs.update(kwargs.pop("params", {}))
27+
self._configs.update(kwargs)
728

829
@property
930
def field_name(self):
@@ -17,12 +38,31 @@ def index_name(self):
1738
def index_type(self):
1839
return self._index_type
1940

20-
def __iter__(self):
21-
yield "field_name", self.field_name
22-
if self.index_type:
23-
yield "index_type", self.index_type
24-
yield "index_name", self.index_name
25-
yield from self._kwargs.items()
41+
def get_index_configs(self) -> Dict:
42+
"""return index_type and index configs in a dict
43+
44+
Examples:
45+
46+
{
47+
"index_type": "HNSW",
48+
"metric_type": "COSINE",
49+
"M": 64,
50+
"efConstruction": 100,
51+
}
52+
"""
53+
return {
54+
"index_type": self.index_type,
55+
**self._configs,
56+
}
57+
58+
def to_dict(self) -> Dict:
59+
"""All params"""
60+
return {
61+
"field_name": self.field_name,
62+
"index_type": self.index_type,
63+
"index_name": self.index_name,
64+
**self._configs,
65+
}
2666

2767
def __str__(self):
2868
return str(dict(self))
@@ -36,19 +76,15 @@ def __eq__(self, other: None):
3676
return False
3777

3878

39-
class IndexParams:
40-
def __init__(self, field_name: str = "", **kwargs):
41-
self._indexes = []
42-
if field_name:
43-
self.add_index(field_name, **kwargs)
79+
class IndexParams(list):
80+
"""List of indexes of a collection"""
81+
82+
def __init__(self, *args, **kwargs):
83+
super().__init__(*args, **kwargs)
4484

4585
def add_index(self, field_name: str, index_type: str = "", index_name: str = "", **kwargs):
4686
index_param = IndexParam(field_name, index_type, index_name, **kwargs)
47-
self._indexes.append(index_param)
48-
49-
def __iter__(self):
50-
for v in self._indexes:
51-
yield dict(v)
87+
super().append(index_param)
5288

5389
def __str__(self):
5490
return str(list(self))

pymilvus/milvus_client/milvus_client.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""MilvusClient for dealing with simple workflows."""
2-
31
import logging
42
from typing import Dict, List, Optional, Union
53
from uuid import uuid4
@@ -32,7 +30,8 @@
3230
from pymilvus.orm.iterator import QueryIterator, SearchIterator
3331
from pymilvus.orm.types import DataType
3432

35-
from .index import IndexParams
33+
from .check import validate_param
34+
from .index import IndexParam, IndexParams
3635

3736
logger = logging.getLogger(__name__)
3837

@@ -112,9 +111,8 @@ def _fast_create_collection(
112111
timeout: Optional[float] = None,
113112
**kwargs,
114113
):
115-
if dimension is None:
116-
msg = "missing requried argument: 'dimension'"
117-
raise TypeError(msg)
114+
validate_param("dimension", dimension, int)
115+
118116
if "enable_dynamic_field" not in kwargs:
119117
kwargs["enable_dynamic_field"] = True
120118

@@ -132,8 +130,7 @@ def _fast_create_collection(
132130
pk_args["max_length"] = kwargs["max_length"]
133131

134132
schema.add_field(primary_field_name, pk_data_type, is_primary=True, **pk_args)
135-
vector_type = DataType.FLOAT_VECTOR
136-
schema.add_field(vector_field_name, vector_type, dim=dimension)
133+
schema.add_field(vector_field_name, DataType.FLOAT_VECTOR, dim=dimension)
137134
schema.verify()
138135

139136
conn = self._get_connection()
@@ -147,7 +144,7 @@ def _fast_create_collection(
147144
raise ex from ex
148145

149146
index_params = IndexParams()
150-
index_params.add_index(vector_field_name, "", "", metric_type=metric_type)
147+
index_params.add_index(vector_field_name, index_type="AUTOINDEX", metric_type=metric_type)
151148
self.create_index(collection_name, index_params, timeout=timeout)
152149
self.load_collection(collection_name, timeout=timeout)
153150

@@ -158,24 +155,29 @@ def create_index(
158155
timeout: Optional[float] = None,
159156
**kwargs,
160157
):
158+
validate_param("collection_name", collection_name, str)
159+
validate_param("index_params", index_params, IndexParams)
160+
if len(index_params) == 0:
161+
raise ParamError(message="IndexParams is empty, no index can be created")
162+
161163
for index_param in index_params:
162164
self._create_index(collection_name, index_param, timeout=timeout, **kwargs)
163165

164166
def _create_index(
165-
self, collection_name: str, index_param: Dict, timeout: Optional[float] = None, **kwargs
167+
self,
168+
collection_name: str,
169+
index_param: IndexParam,
170+
timeout: Optional[float] = None,
171+
**kwargs,
166172
):
167173
conn = self._get_connection()
168174
try:
169-
params = index_param.pop("params", {})
170-
field_name = index_param.pop("field_name", "")
171-
index_name = index_param.pop("index_name", "")
172-
params.update(index_param)
173175
conn.create_index(
174176
collection_name,
175-
field_name,
176-
params,
177+
index_param.field_name,
178+
index_param.get_index_configs(),
177179
timeout=timeout,
178-
index_name=index_name,
180+
index_name=index_param.index_name,
179181
**kwargs,
180182
)
181183
logger.debug("Successfully created an index on collection: %s", collection_name)
@@ -872,8 +874,11 @@ def create_schema(cls, **kwargs):
872874
return CollectionSchema([], **kwargs)
873875

874876
@classmethod
875-
def prepare_index_params(cls, field_name: str = "", **kwargs):
876-
return IndexParams(field_name, **kwargs)
877+
def prepare_index_params(cls, field_name: str = "", **kwargs) -> IndexParams:
878+
index_params = IndexParams()
879+
if field_name and validate_param("field_name", field_name, str):
880+
index_params.add_index(field_name, **kwargs)
881+
return index_params
877882

878883
def _create_collection_with_schema(
879884
self,

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
# https://github.com/grpc/grpc/blob/5918f98ecbf5ace77f30fa97f7fc3e8bdac08e04/src/python/grpcio_tests/tests/testing/_client_test.py
66
from grpc.framework.foundation import logging_pool
7+
import logging
8+
logging.getLogger("faker").setLevel(logging.WARNING)
79

810
from pymilvus.grpc_gen import milvus_pb2
911

tests/test_milvus_client.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
5+
from pymilvus.exceptions import ParamError
6+
from pymilvus.milvus_client.index import IndexParams
7+
from pymilvus.milvus_client.milvus_client import MilvusClient
8+
9+
10+
class TestMilvusClient:
11+
@pytest.mark.parametrize("index_params", [None, {}, "str", MilvusClient.prepare_index_params()])
12+
def test_create_index_invalid_params(self, index_params):
13+
with patch("pymilvus.orm.utility.get_server_type", return_value="milvus"), patch('pymilvus.milvus_client.milvus_client.MilvusClient._create_connection', return_value="test"):
14+
client = MilvusClient()
15+
16+
if isinstance(index_params, IndexParams):
17+
with pytest.raises(ParamError, match="IndexParams is empty, no index can be created"):
18+
client.create_index("test_collection", index_params)
19+
elif index_params is None:
20+
with pytest.raises(ParamError, match="missing required argument:.*"):
21+
client.create_index("test_collection", index_params)
22+
else:
23+
with pytest.raises(ParamError, match="wrong type of argument .*"):
24+
client.create_index("test_collection", index_params)

0 commit comments

Comments
 (0)