Skip to content

Commit f70c29b

Browse files
committed
update Function interface to use input/output_field_names
Signed-off-by: Buqian Zheng <[email protected]>
1 parent 26cc8dd commit f70c29b

File tree

6 files changed

+48
-48
lines changed

6 files changed

+48
-48
lines changed

examples/hello_bm25.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@
5757
bm25_function = Function(
5858
name="bm25",
5959
function_type=FunctionType.BM25,
60-
inputs=["document"],
61-
outputs=["sparse"],
62-
params={"bm25_k1": 1.2, "bm25_b": 0.75},
60+
input_field_names=["document"],
61+
output_field_names="sparse",
6362
)
6463

6564
schema = CollectionSchema(fields, "hello_bm25 demo")
@@ -95,13 +94,12 @@
9594

9695
################################################################################
9796
# 4. create index
98-
# We are going to create an SPARSE_INVERTED_INDEX index for hello_bm25 collection.
99-
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
100-
print(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
97+
# We are going to create an index for the hello_bm25 collection; here we simply
98+
# use AUTOINDEX so Milvus can use the default parameters.
99+
print(fmt.format("Start Creating index AUTOINDEX"))
101100
index = {
102-
"index_type": "SPARSE_INVERTED_INDEX",
101+
"index_type": "AUTOINDEX",
103102
"metric_type": "BM25",
104-
'params': {"bm25_k1": 1.2, "bm25_b": 0.75},
105103
}
106104

107105
hello_bm25.create_index("sparse", index)

examples/hello_hybrid_bm25.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,8 @@ def random_embedding(texts):
103103
Function(
104104
name="bm25",
105105
function_type=FunctionType.BM25,
106-
inputs=["text"],
107-
outputs=["sparse_vector"],
108-
params={"bm25_k1": 1.2, "bm25_b": 0.75},
106+
input_field_names=["text"],
107+
output_field_names="sparse_vector",
109108
)
110109
]
111110
schema = CollectionSchema(fields, "", functions=functions)

examples/milvus_client/bm25.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020

2121
bm25_function = Function(
2222
name="bm25_fn",
23-
inputs=["document_content"],
24-
outputs=["sparse_vector"],
23+
input_field_names=["document_content"],
24+
output_field_names="sparse_vector",
2525
function_type=FunctionType.BM25,
26-
params={"bm25_k1": 1.2, "bm25_b": 0.75},
2726
)
2827
schema.add_function(bm25_function)
2928

pymilvus/orm/prepare.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# or implied. See the License for the specific language governing permissions and limitations under
1111
# the License.
1212

13-
import copy
1413
from typing import List, Tuple, Union
1514

1615
import numpy as np

pymilvus/orm/schema.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pandas as pd
1717
from pandas.api.types import is_list_like, is_scalar
1818

19+
from pymilvus.client.types import FunctionType
1920
from pymilvus.exceptions import (
2021
AutoIDException,
2122
CannotInferSchemaException,
@@ -24,19 +25,18 @@
2425
DataTypeNotSupportException,
2526
ExceptionsMessage,
2627
FieldsTypeException,
27-
FunctionsTypeException,
2828
FieldTypeException,
29+
FunctionsTypeException,
2930
ParamError,
3031
PartitionKeyException,
3132
PrimaryKeyException,
3233
SchemaNotReadyException,
3334
)
3435
from pymilvus.grpc_gen import schema_pb2 as schema_types
3536

36-
from .constants import COMMON_TYPE_PARAMS, BM25_k1, BM25_b
37+
from .constants import COMMON_TYPE_PARAMS
3738
from .types import (
3839
DataType,
39-
FunctionType,
4040
infer_dtype_by_scalar_data,
4141
infer_dtype_bydata,
4242
map_numpy_dtype_to_datatype,
@@ -87,7 +87,9 @@ def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field
8787

8888

8989
class CollectionSchema:
90-
def __init__(self, fields: List, description: str = "", functions: List = [], **kwargs):
90+
def __init__(
91+
self, fields: List, description: str = "", functions: Optional[List] = None, **kwargs
92+
):
9193
self._kwargs = copy.deepcopy(kwargs)
9294
self._fields = []
9395
self._description = description
@@ -97,6 +99,9 @@ def __init__(self, fields: List, description: str = "", functions: List = [], **
9799
self._partition_key_field = None
98100
self._clustering_key_field = None
99101

102+
if functions is None:
103+
functions = []
104+
100105
if not isinstance(functions, list):
101106
raise FunctionsTypeException(message=ExceptionsMessage.FunctionsType)
102107
for function in functions:
@@ -352,7 +357,7 @@ def add_field(self, field_name: str, datatype: DataType, **kwargs):
352357
self._mark_output_fields()
353358
return self
354359

355-
def add_function(self, function):
360+
def add_function(self, function: "Function"):
356361
if not isinstance(function, Function):
357362
raise ParamError(message=ExceptionsMessage.FunctionIncorrectType)
358363
self._functions.append(function)
@@ -459,7 +464,6 @@ def to_dict(self):
459464
"name": self.name,
460465
"description": self._description,
461466
"type": self.dtype,
462-
"is_function_output": self.is_function_output,
463467
}
464468
if self._type_params:
465469
_dict["params"] = copy.deepcopy(self.params)
@@ -480,6 +484,8 @@ def to_dict(self):
480484
_dict["element_type"] = self.element_type
481485
if self.is_clustering_key:
482486
_dict["is_clustering_key"] = True
487+
if self.is_function_output:
488+
_dict["is_function_output"] = True
483489
return _dict
484490

485491
def __getattr__(self, item: str):
@@ -537,36 +543,38 @@ def __init__(
537543
self,
538544
name: str,
539545
function_type: FunctionType,
540-
inputs: List[str],
541-
outputs: List[str],
546+
input_field_names: Union[str, List[str]],
547+
output_field_names: Union[str, List[str]],
542548
description: str = "",
543-
params: Dict = {},
549+
params: Optional[Dict] = None,
544550
):
545551
self._name = name
546552
self._description = description
553+
input_field_names = (
554+
[input_field_names] if isinstance(input_field_names, str) else input_field_names
555+
)
556+
output_field_names = (
557+
[output_field_names] if isinstance(output_field_names, str) else output_field_names
558+
)
547559
try:
548560
self._type = FunctionType(function_type)
549-
except ValueError:
550-
raise ParamError(message=ExceptionsMessage.UnknownFunctionType)
561+
except ValueError as err:
562+
raise ParamError(message=ExceptionsMessage.UnknownFunctionType) from err
551563

552-
for field_name in list(inputs) + list(outputs):
564+
for field_name in list(input_field_names) + list(output_field_names):
553565
if not isinstance(field_name, str):
554566
raise ParamError(message=ExceptionsMessage.FunctionIncorrectInputOutputType)
555-
if len(inputs) != len(set(inputs)):
567+
if len(input_field_names) != len(set(input_field_names)):
556568
raise ParamError(message=ExceptionsMessage.FunctionDuplicateInputs)
557-
if len(outputs) != len(set(outputs)):
569+
if len(output_field_names) != len(set(output_field_names)):
558570
raise ParamError(message=ExceptionsMessage.FunctionDuplicateOutputs)
559571

560-
if set(inputs) & set(outputs):
572+
if set(input_field_names) & set(output_field_names):
561573
raise ParamError(message=ExceptionsMessage.FunctionCommonInputOutput)
562574

563-
self._input_field_names = inputs
564-
self._output_field_names = outputs
565-
if BM25_k1 in params:
566-
params[BM25_k1] = str(params[BM25_k1])
567-
if BM25_b in params:
568-
params[BM25_b] = str(params[BM25_b])
569-
self._params = params
575+
self._input_field_names = input_field_names
576+
self._output_field_names = output_field_names
577+
self._params = params if params is not None else {}
570578

571579
@property
572580
def name(self):
@@ -598,16 +606,13 @@ def verify(self, schema: CollectionSchema):
598606
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputOutputCount)
599607

600608
for field in schema.fields:
601-
if field.name == self._input_field_names[0]:
602-
if field.dtype != DataType.VARCHAR:
603-
raise ParamError(
604-
message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType
605-
)
606-
if field.name == self._output_field_names[0]:
607-
if field.dtype != DataType.SPARSE_FLOAT_VECTOR:
608-
raise ParamError(
609-
message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType
610-
)
609+
if field.name == self._input_field_names[0] and field.dtype != DataType.VARCHAR:
610+
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType)
611+
if (
612+
field.name == self._output_field_names[0]
613+
and field.dtype != DataType.SPARSE_FLOAT_VECTOR
614+
):
615+
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType)
611616

612617
elif self._type == FunctionType.UNKNOWN:
613618
raise ParamError(message=ExceptionsMessage.UnknownFunctionType)

pymilvus/orm/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
is_scalar,
2222
)
2323

24-
from pymilvus.client.types import DataType, FunctionType
24+
from pymilvus.client.types import DataType
2525

2626
dtype_str_map = {
2727
"string": DataType.VARCHAR,

0 commit comments

Comments
 (0)