16
16
import pandas as pd
17
17
from pandas .api .types import is_list_like , is_scalar
18
18
19
+ from pymilvus .client .types import FunctionType
19
20
from pymilvus .exceptions import (
20
21
AutoIDException ,
21
22
CannotInferSchemaException ,
24
25
DataTypeNotSupportException ,
25
26
ExceptionsMessage ,
26
27
FieldsTypeException ,
27
- FunctionsTypeException ,
28
28
FieldTypeException ,
29
+ FunctionsTypeException ,
29
30
ParamError ,
30
31
PartitionKeyException ,
31
32
PrimaryKeyException ,
32
33
SchemaNotReadyException ,
33
34
)
34
35
from pymilvus .grpc_gen import schema_pb2 as schema_types
35
36
36
- from .constants import COMMON_TYPE_PARAMS , BM25_k1 , BM25_b
37
+ from .constants import COMMON_TYPE_PARAMS
37
38
from .types import (
38
39
DataType ,
39
- FunctionType ,
40
40
infer_dtype_by_scalar_data ,
41
41
infer_dtype_bydata ,
42
42
map_numpy_dtype_to_datatype ,
@@ -87,7 +87,9 @@ def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field
87
87
88
88
89
89
class CollectionSchema :
90
- def __init__ (self , fields : List , description : str = "" , functions : List = [], ** kwargs ):
90
+ def __init__ (
91
+ self , fields : List , description : str = "" , functions : Optional [List ] = None , ** kwargs
92
+ ):
91
93
self ._kwargs = copy .deepcopy (kwargs )
92
94
self ._fields = []
93
95
self ._description = description
@@ -97,6 +99,9 @@ def __init__(self, fields: List, description: str = "", functions: List = [], **
97
99
self ._partition_key_field = None
98
100
self ._clustering_key_field = None
99
101
102
+ if functions is None :
103
+ functions = []
104
+
100
105
if not isinstance (functions , list ):
101
106
raise FunctionsTypeException (message = ExceptionsMessage .FunctionsType )
102
107
for function in functions :
@@ -352,7 +357,7 @@ def add_field(self, field_name: str, datatype: DataType, **kwargs):
352
357
self ._mark_output_fields ()
353
358
return self
354
359
355
- def add_function (self , function ):
360
+ def add_function (self , function : "Function" ):
356
361
if not isinstance (function , Function ):
357
362
raise ParamError (message = ExceptionsMessage .FunctionIncorrectType )
358
363
self ._functions .append (function )
@@ -459,7 +464,6 @@ def to_dict(self):
459
464
"name" : self .name ,
460
465
"description" : self ._description ,
461
466
"type" : self .dtype ,
462
- "is_function_output" : self .is_function_output ,
463
467
}
464
468
if self ._type_params :
465
469
_dict ["params" ] = copy .deepcopy (self .params )
@@ -480,6 +484,8 @@ def to_dict(self):
480
484
_dict ["element_type" ] = self .element_type
481
485
if self .is_clustering_key :
482
486
_dict ["is_clustering_key" ] = True
487
+ if self .is_function_output :
488
+ _dict ["is_function_output" ] = True
483
489
return _dict
484
490
485
491
def __getattr__ (self , item : str ):
@@ -537,36 +543,38 @@ def __init__(
537
543
self ,
538
544
name : str ,
539
545
function_type : FunctionType ,
540
- inputs : List [str ],
541
- outputs : List [str ],
546
+ input_field_names : Union [ str , List [str ] ],
547
+ output_field_names : Union [ str , List [str ] ],
542
548
description : str = "" ,
543
- params : Dict = {} ,
549
+ params : Optional [ Dict ] = None ,
544
550
):
545
551
self ._name = name
546
552
self ._description = description
553
+ input_field_names = (
554
+ [input_field_names ] if isinstance (input_field_names , str ) else input_field_names
555
+ )
556
+ output_field_names = (
557
+ [output_field_names ] if isinstance (output_field_names , str ) else output_field_names
558
+ )
547
559
try :
548
560
self ._type = FunctionType (function_type )
549
- except ValueError :
550
- raise ParamError (message = ExceptionsMessage .UnknownFunctionType )
561
+ except ValueError as err :
562
+ raise ParamError (message = ExceptionsMessage .UnknownFunctionType ) from err
551
563
552
- for field_name in list (inputs ) + list (outputs ):
564
+ for field_name in list (input_field_names ) + list (output_field_names ):
553
565
if not isinstance (field_name , str ):
554
566
raise ParamError (message = ExceptionsMessage .FunctionIncorrectInputOutputType )
555
- if len (inputs ) != len (set (inputs )):
567
+ if len (input_field_names ) != len (set (input_field_names )):
556
568
raise ParamError (message = ExceptionsMessage .FunctionDuplicateInputs )
557
- if len (outputs ) != len (set (outputs )):
569
+ if len (output_field_names ) != len (set (output_field_names )):
558
570
raise ParamError (message = ExceptionsMessage .FunctionDuplicateOutputs )
559
571
560
- if set (inputs ) & set (outputs ):
572
+ if set (input_field_names ) & set (output_field_names ):
561
573
raise ParamError (message = ExceptionsMessage .FunctionCommonInputOutput )
562
574
563
- self ._input_field_names = inputs
564
- self ._output_field_names = outputs
565
- if BM25_k1 in params :
566
- params [BM25_k1 ] = str (params [BM25_k1 ])
567
- if BM25_b in params :
568
- params [BM25_b ] = str (params [BM25_b ])
569
- self ._params = params
575
+ self ._input_field_names = input_field_names
576
+ self ._output_field_names = output_field_names
577
+ self ._params = params if params is not None else {}
570
578
571
579
@property
572
580
def name (self ):
@@ -598,16 +606,13 @@ def verify(self, schema: CollectionSchema):
598
606
raise ParamError (message = ExceptionsMessage .BM25FunctionIncorrectInputOutputCount )
599
607
600
608
for field in schema .fields :
601
- if field .name == self ._input_field_names [0 ]:
602
- if field .dtype != DataType .VARCHAR :
603
- raise ParamError (
604
- message = ExceptionsMessage .BM25FunctionIncorrectInputFieldType
605
- )
606
- if field .name == self ._output_field_names [0 ]:
607
- if field .dtype != DataType .SPARSE_FLOAT_VECTOR :
608
- raise ParamError (
609
- message = ExceptionsMessage .BM25FunctionIncorrectOutputFieldType
610
- )
609
+ if field .name == self ._input_field_names [0 ] and field .dtype != DataType .VARCHAR :
610
+ raise ParamError (message = ExceptionsMessage .BM25FunctionIncorrectInputFieldType )
611
+ if (
612
+ field .name == self ._output_field_names [0 ]
613
+ and field .dtype != DataType .SPARSE_FLOAT_VECTOR
614
+ ):
615
+ raise ParamError (message = ExceptionsMessage .BM25FunctionIncorrectOutputFieldType )
611
616
612
617
elif self ._type == FunctionType .UNKNOWN :
613
618
raise ParamError (message = ExceptionsMessage .UnknownFunctionType )
0 commit comments