@@ -44,6 +44,8 @@ def is_float_type(v: Any):
44
44
if len (entity ) == 0 :
45
45
return False
46
46
for item in entity :
47
+ if SciPyHelper .is_scipy_sparse (item ):
48
+ return item .shape [0 ] == 1
47
49
pairs = item .items () if isinstance (item , dict ) else item
48
50
# each row must be a non-empty list of Tuple[int, float]
49
51
if len (pairs ) == 0 :
@@ -103,14 +105,22 @@ def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]):
103
105
else :
104
106
dim = 0
105
107
for _ , row_data in enumerate (data ):
106
- indices = []
107
- values = []
108
- row = row_data .items () if isinstance (row_data , dict ) else row_data
109
- for index , value in row :
110
- indices .append (int (index ))
111
- values .append (float (value ))
112
- result .contents .append (sparse_float_row_to_bytes (indices , values ))
113
- dim = max (dim , indices [- 1 ] + 1 )
108
+ if SciPyHelper .is_scipy_sparse (row_data ):
109
+ if row_data .shape [0 ] != 1 :
110
+ raise ParamError (message = "invalid input for sparse float vector: expect 1 row" )
111
+ dim = max (dim , row_data .shape [1 ])
112
+ result .contents .append (
113
+ sparse_float_row_to_bytes (row_data .indices , row_data .data )
114
+ )
115
+ else :
116
+ indices = []
117
+ values = []
118
+ row = row_data .items () if isinstance (row_data , dict ) else row_data
119
+ for index , value in row :
120
+ indices .append (int (index ))
121
+ values .append (float (value ))
122
+ result .contents .append (sparse_float_row_to_bytes (indices , values ))
123
+ dim = max (dim , indices [- 1 ] + 1 )
114
124
result .dim = dim
115
125
return result
116
126
0 commit comments