@@ -289,7 +289,7 @@ def _persist_parquet(self, local_path: str, **kwargs):
289
289
290
290
def _persist_csv (self , local_path : str , ** kwargs ):
291
291
sep = self ._config .get ("sep" , "," )
292
- # nullkey is not supported in csv now
292
+ nullkey = self . _config . get ( "nullkey" , "" )
293
293
294
294
header = list (self ._buffer .keys ())
295
295
data = pd .DataFrame (columns = header )
@@ -307,17 +307,14 @@ def _persist_csv(self, local_path: str, **kwargs):
307
307
# 2. or convert arr into a string using json.dumps(arr) first and then add it to df
308
308
# I choose method 2 here
309
309
if field_schema .dtype in {
310
- DataType .JSON ,
311
- DataType .ARRAY ,
312
310
DataType .SPARSE_FLOAT_VECTOR ,
313
311
DataType .BINARY_VECTOR ,
314
312
DataType .FLOAT_VECTOR ,
315
313
}:
316
- dt = np .dtype ("str" )
317
314
arr = []
318
315
for val in v :
319
316
arr .append (json .dumps (val ))
320
- data [k ] = pd .Series (arr , dtype = dt )
317
+ data [k ] = pd .Series (arr , dtype = np . dtype ( "str" ) )
321
318
elif field_schema .dtype in {DataType .FLOAT16_VECTOR , DataType .BFLOAT16_VECTOR }:
322
319
# special process for float16 vector, the self._buffer stores bytes for
323
320
# float16 vector, convert the bytes to float list
@@ -330,19 +327,31 @@ def _persist_csv(self, local_path: str, **kwargs):
330
327
for val in v :
331
328
arr .append (json .dumps (np .frombuffer (val , dtype = dt ).tolist ()))
332
329
data [k ] = pd .Series (arr , dtype = np .dtype ("str" ))
330
+ elif field_schema .dtype in {
331
+ DataType .JSON ,
332
+ DataType .ARRAY ,
333
+ }:
334
+ arr = []
335
+ for val in v :
336
+ if val is None :
337
+ arr .append (nullkey )
338
+ else :
339
+ arr .append (json .dumps (val ))
340
+ data [k ] = pd .Series (arr , dtype = np .dtype ("str" ))
333
341
elif field_schema .dtype in {DataType .BOOL }:
334
- dt = np .dtype ("str" )
335
- arr = ["true" if x else "false" for x in v ]
336
- data [k ] = pd .Series (arr , dtype = dt )
337
- elif field_schema .dtype .name in NUMPY_TYPE_CREATOR :
338
- dt = NUMPY_TYPE_CREATOR [field_schema .dtype .name ]
339
- data [k ] = pd .Series (v , dtype = dt )
342
+ arr = []
343
+ for val in v :
344
+ if val is not None :
345
+ arr .append ("true" if val else "false" )
346
+ data [k ] = pd .Series (arr , dtype = np .dtype ("str" ))
340
347
else :
341
- data [k ] = pd .Series (v )
348
+ data [k ] = pd .Series (v , dtype = NUMPY_TYPE_CREATOR [ field_schema . dtype . name ] )
342
349
343
350
file_path = Path (local_path + ".csv" )
344
351
try :
345
- data .to_csv (file_path , sep = sep , index = False )
352
+ # pd.Series will convert None to np.nan,
353
+ # so we can use 'na_rep=nullkey' to replace NaN with nullkey
354
+ data .to_csv (file_path , sep = sep , na_rep = nullkey , index = False )
346
355
except Exception as e :
347
356
self ._throw (f"Failed to persist file { file_path } , error: { e } " )
348
357
0 commit comments