Skip to content

Commit 6e6183d

Browse files
committed
remove scipy dependency for sparse while still supporting it
Signed-off-by: Buqian Zheng <[email protected]>
1 parent 97f12ae commit 6e6183d

File tree

11 files changed

+157
-113
lines changed

11 files changed

+157
-113
lines changed

examples/hello_sparse.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111

1212
import numpy as np
13-
from scipy.sparse import rand
13+
import random
1414
from pymilvus import (
1515
connections,
1616
utility,
@@ -20,7 +20,9 @@
2020

2121
fmt = "=== {:30} ==="
2222
search_latency_fmt = "search latency = {:.4f}s"
23-
num_entities, dim, density = 1000, 3000, 0.005
23+
num_entities, dim = 1000, 3000
24+
# non zero count of randomly generated sparse vectors
25+
nnz = 30
2426

2527
def log(msg):
2628
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)
@@ -54,11 +56,16 @@ def log(msg):
5456
# insert
5557
log(fmt.format("Start creating entities to insert"))
5658
rng = np.random.default_rng(seed=19530)
57-
# this step is so damn slow
58-
matrix_csr = rand(num_entities, dim, density=density, format='csr')
59+
60+
def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
61+
indices = random.sample(range(dimension), non_zero_count)
62+
values = [random.random() for _ in range(non_zero_count)]
63+
sparse_vector = {index: value for index, value in zip(indices, values)}
64+
return sparse_vector
65+
5966
entities = [
6067
rng.random(num_entities).tolist(),
61-
matrix_csr,
68+
[generate_sparse_vector(dim, nnz) for _ in range(num_entities)],
6269
]
6370

6471
log(fmt.format("Start inserting entities"))

pymilvus/client/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pymilvus.grpc_gen import common_pb2, schema_pb2
88
from pymilvus.settings import Config
99

10-
from . import entity_helper
10+
from . import entity_helper, utils
1111
from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
1212
from .types import DataType
1313

@@ -337,7 +337,7 @@ def dict(self):
337337
class AnnSearchRequest:
338338
def __init__(
339339
self,
340-
data: Union[List, entity_helper.SparseMatrixInputType],
340+
data: Union[List, utils.SparseMatrixInputType],
341341
anns_field: str,
342342
param: Dict,
343343
limit: int,

pymilvus/client/entity_helper.py

Lines changed: 26 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import math
22
import struct
3-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
3+
from typing import Any, Dict, Iterable, List, Optional
44

55
import numpy as np
66
import ujson
7-
from scipy import sparse
87

98
from pymilvus.exceptions import (
109
DataNotMatchException,
@@ -16,67 +15,13 @@
1615
from pymilvus.settings import Config
1716

1817
from .types import DataType
18+
from .utils import SciPyHelper, SparseMatrixInputType, SparseRowOutputType
1919

2020
CHECK_STR_ARRAY = True
2121

22-
# in search results, if output fields includes a sparse float vector field, we
23-
# will return a SparseRowOutputType for each entity. Using Dict for readability.
24-
# TODO(SPARSE): to allow the user to specify output format.
25-
SparseRowOutputType = Dict[int, float]
26-
27-
# we accept the following types as input for sparse matrix in user facing APIs
28-
# such as insert, search, etc.:
29-
# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil
30-
# - iterable of iterables, each element(iterable) is a sparse vector with index
31-
# as key and value as float.
32-
# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...]
33-
# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...]
34-
# both index/value can be str numbers: {'2': '3.1'}
35-
SparseMatrixInputType = Union[
36-
Iterable[
37-
Union[
38-
SparseRowOutputType,
39-
Iterable[Tuple[int, float]], # only type hint, we accept int/float like types
40-
]
41-
],
42-
sparse.csc_array,
43-
sparse.coo_array,
44-
sparse.bsr_array,
45-
sparse.dia_array,
46-
sparse.dok_array,
47-
sparse.lil_array,
48-
sparse.csr_array,
49-
sparse.spmatrix,
50-
]
51-
52-
53-
def sparse_is_scipy_matrix(data: Any):
54-
return isinstance(data, sparse.spmatrix)
55-
56-
57-
def sparse_is_scipy_array(data: Any):
58-
# sparse.sparray, the common superclass of sparse.*_array, is introduced in
59-
# scipy 1.11.0, which requires python 3.9, higher than pymilvus's current requirement.
60-
return isinstance(
61-
data,
62-
(
63-
sparse.bsr_array,
64-
sparse.coo_array,
65-
sparse.csc_array,
66-
sparse.csr_array,
67-
sparse.dia_array,
68-
sparse.dok_array,
69-
sparse.lil_array,
70-
),
71-
)
72-
73-
74-
def sparse_is_scipy_format(data: Any):
75-
return sparse_is_scipy_matrix(data) or sparse_is_scipy_array(data)
76-
7722

7823
def entity_is_sparse_matrix(entity: Any):
79-
if sparse_is_scipy_format(entity):
24+
if SciPyHelper.is_scipy_sparse(entity):
8025
return True
8126
try:
8227

@@ -143,34 +88,30 @@ def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]):
14388
data += struct.pack("f", v)
14489
return data
14590

146-
def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array:
147-
if isinstance(data, sparse.csr_array):
148-
return data
149-
if sparse_is_scipy_array(data):
150-
return data.tocsr()
151-
if sparse_is_scipy_matrix(data):
152-
return sparse.csr_array(data.tocsr())
153-
row_indices = []
154-
col_indices = []
155-
values = []
156-
for row_id, row_data in enumerate(data):
157-
row = row_data.items() if isinstance(row_data, dict) else row_data
158-
row_indices.extend([row_id] * len(row))
159-
col_indices.extend(
160-
[int(col_id) if isinstance(col_id, str) else col_id for col_id, _ in row]
161-
)
162-
values.extend([float(value) if isinstance(value, str) else value for _, value in row])
163-
return sparse.csr_array((values, (row_indices, col_indices)))
164-
16591
if not entity_is_sparse_matrix(data):
16692
raise ParamError(message="input must be a sparse matrix in supported format")
167-
csr = unify_sparse_input(data)
93+
16894
result = schema_types.SparseFloatArray()
169-
result.dim = csr.shape[1]
170-
for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
171-
result.contents.append(
172-
sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
173-
)
95+
96+
if SciPyHelper.is_scipy_sparse(data):
97+
csr = data.tocsr()
98+
result.dim = csr.shape[1]
99+
for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
100+
result.contents.append(
101+
sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
102+
)
103+
else:
104+
dim = 0
105+
for _, row_data in enumerate(data):
106+
indices = []
107+
values = []
108+
row = row_data.items() if isinstance(row_data, dict) else row_data
109+
for index, value in row:
110+
indices.append(index)
111+
values.append(value)
112+
result.contents.append(sparse_float_row_to_bytes(indices, values))
113+
dim = max(dim, indices[-1] + 1)
114+
result.dim = dim
174115
return result
175116

176117

@@ -186,7 +127,7 @@ def sparse_proto_to_rows(
186127

187128

188129
def get_input_num_rows(entity: Any) -> int:
189-
if sparse_is_scipy_format(entity):
130+
if SciPyHelper.is_scipy_sparse(entity):
190131
return entity.shape[0]
191132
return len(entity)
192133

@@ -354,7 +295,7 @@ def pack_field_value_to_field_data(
354295
field_data.vectors.bfloat16_vector += v_bytes
355296
elif field_type == DataType.SPARSE_FLOAT_VECTOR:
356297
# field_value is a single row of sparse float vector in user provided format
357-
if not sparse_is_scipy_format(field_value):
298+
if not SciPyHelper.is_scipy_sparse(field_value):
358299
field_value = [field_value]
359300
elif field_value.shape[0] != 1:
360301
raise ParamError(message="invalid input for sparse float vector: expect 1 row")

pymilvus/client/grpc_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
2424
from pymilvus.settings import Config
2525

26-
from . import entity_helper, interceptor, ts_utils
26+
from . import entity_helper, interceptor, ts_utils, utils
2727
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
2828
from .asynch import (
2929
CreateIndexFuture,
@@ -763,7 +763,7 @@ def _execute_hybrid_search(
763763
def search(
764764
self,
765765
collection_name: str,
766-
data: Union[List[List[float]], entity_helper.SparseMatrixInputType],
766+
data: Union[List[List[float]], utils.SparseMatrixInputType],
767767
anns_field: str,
768768
param: Dict,
769769
limit: int,

pymilvus/client/prepare.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55
import numpy as np
66

7-
from pymilvus.client import __version__, entity_helper
87
from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError
98
from pymilvus.grpc_gen import common_pb2 as common_types
109
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
1110
from pymilvus.grpc_gen import schema_pb2 as schema_types
1211
from pymilvus.orm.schema import CollectionSchema
1312

14-
from . import blob, ts_utils, utils
13+
from . import __version__, blob, entity_helper, ts_utils, utils
1514
from .check import check_pass_param, is_legal_collection_properties
1615
from .constants import (
1716
DEFAULT_CONSISTENCY_LEVEL,
@@ -626,7 +625,7 @@ def _prepare_placeholder_str(cls, data: Any):
626625
def search_requests_with_expr(
627626
cls,
628627
collection_name: str,
629-
data: Union[List, entity_helper.SparseMatrixInputType],
628+
data: Union[List, utils.SparseMatrixInputType],
630629
anns_field: str,
631630
param: Dict,
632631
limit: int,

pymilvus/client/utils.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
2+
import importlib.util
23
from datetime import timedelta
3-
from typing import Any, List, Optional, Union
4+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
45

56
import ujson
67

@@ -270,3 +271,100 @@ def get_server_type(host: str):
270271

271272
def dumps(v: Union[dict, str]) -> str:
272273
return ujson.dumps(v) if isinstance(v, dict) else str(v)
274+
275+
276+
class SciPyHelper:
277+
_checked = False
278+
279+
# whether scipy.sparse.*_matrix classes exists
280+
_matrix_available = False
281+
# whether scipy.sparse.*_array classes exists
282+
_array_available = False
283+
284+
@classmethod
285+
def _init(cls):
286+
if cls._checked:
287+
return
288+
289+
sparse_spec = importlib.util.find_spec("scipy.sparse")
290+
if sparse_spec is not None:
291+
scipy_sparse = importlib.util.module_from_spec(sparse_spec)
292+
sparse_spec.loader.exec_module(scipy_sparse)
293+
# all scipy.sparse.*_matrix classes are introduced in the same scipy
294+
# version, so we only need to check one of them.
295+
cls._matrix_available = hasattr(scipy_sparse, "csr_matrix")
296+
# all scipy.sparse.*_array classes are introduced in the same scipy
297+
# version, so we only need to check one of them.
298+
cls._array_available = hasattr(scipy_sparse, "csr_array")
299+
else:
300+
cls._matrix_available = False
301+
cls._array_available = False
302+
303+
cls._checked = True
304+
305+
@classmethod
306+
def is_spmatrix(cls, data: Any):
307+
cls._init()
308+
if not cls._matrix_available:
309+
return False
310+
from scipy.sparse import isspmatrix
311+
312+
return isspmatrix(data)
313+
314+
@classmethod
315+
def is_sparray(cls, data: Any):
316+
cls._init()
317+
if not cls._array_available:
318+
return False
319+
from scipy.sparse import issparse, isspmatrix
320+
321+
return issparse(data) and not isspmatrix(data)
322+
323+
@classmethod
324+
def is_scipy_sparse(cls, data: Any):
325+
return cls.is_spmatrix(data) or cls.is_sparray(data)
326+
327+
328+
# in search results, if output fields includes a sparse float vector field, we
329+
# will return a SparseRowOutputType for each entity. Using Dict for readability.
330+
# TODO(SPARSE): to allow the user to specify output format.
331+
SparseRowOutputType = Dict[int, float]
332+
333+
334+
# this import will be called only during static type checking
335+
if TYPE_CHECKING:
336+
from scipy.sparse import (
337+
bsr_array,
338+
coo_array,
339+
csc_array,
340+
csr_array,
341+
dia_array,
342+
dok_array,
343+
lil_array,
344+
spmatrix,
345+
)
346+
347+
# we accept the following types as input for sparse matrix in user facing APIs
348+
# such as insert, search, etc.:
349+
# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil
350+
# - iterable of iterables, each element(iterable) is a sparse vector with index
351+
# as key and value as float.
352+
# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...]
353+
# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...]
354+
# both index/value can be str numbers: {'2': '3.1'}
355+
SparseMatrixInputType = Union[
356+
Iterable[
357+
Union[
358+
SparseRowOutputType,
359+
Iterable[Tuple[int, float]], # only type hint, we accept int/float like types
360+
]
361+
],
362+
"csc_array",
363+
"coo_array",
364+
"bsr_array",
365+
"dia_array",
366+
"dok_array",
367+
"lil_array",
368+
"csr_array",
369+
"spmatrix",
370+
]

0 commit comments

Comments
 (0)