Commit c62b9b0

add DATAFRAME_FRAMEWORK variable to branch pandas and polars
1 parent 45623d1

2 files changed (+177, -132 lines)


gokart/file_processor.py

Lines changed: 153 additions & 132 deletions
@@ -17,6 +17,16 @@
 logger = getLogger(__name__)
 
 
+try:
+    import polars as pl
+
+    DATAFRAME_FRAMEWORK = 'polars'
+except ImportError:
+    import pandas as pd
+
+    DATAFRAME_FRAMEWORK = 'pandas'
+
+
 class FileProcessor(object):
     @abstractmethod
     def format(self):
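Note on the new module-level switch above: the dataframe framework is now decided once, at import time, by probing for polars and falling back to pandas, rather than wrapping every load/dump body in try/except ImportError as the removed code below does. A minimal sketch of how calling code could inspect the selection (import path taken from this file's location; the snippet itself is illustrative, not part of the diff):

    import gokart.file_processor as fp

    # 'polars' if polars was importable when the module loaded, else 'pandas'
    print(fp.DATAFRAME_FRAMEWORK)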
@@ -122,38 +132,39 @@ class CsvFileProcessor(FileProcessor):
     def __init__(self, sep=',', encoding: str = 'utf-8'):
         self._sep = sep
         self._encoding = encoding
-        super(CsvFileProcessor, self).__init__()
+        super().__init__()
 
     def format(self):
         return TextFormat(encoding=self._encoding)
 
     def load(self, file):
-        try:
-            import pandas as pd
-
-            try:
-                return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
-            except pd.errors.EmptyDataError:
-                return pd.DataFrame()
-        except ImportError:
-            import polars as pl
-
-            try:
-                return pl.read_csv(file, sep=self._sep, encoding=self._encoding)
-            except pl.exceptions.NoDataError:
-                return pd.DataFrame()
+        ...
 
     def dump(self, obj, file):
-        try:
-            import pandas as pd
-
-            assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
-            obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
-        except ImportError:
-            import polars as pl
-
-            assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.'
-            obj.write_csv(file, separator=self._sep, include_header=True)
+        ...
+
+
+class PolarsCsvFileProcessor(CsvFileProcessor):
+    def load(self, file):
+        try:
+            return pl.read_csv(file, sep=self._sep, encoding=self._encoding)
+        except pl.exceptions.NoDataError:
+            return pl.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.'
+        obj.write_csv(file, separator=self._sep, include_header=True)
+
+
+class PandasCsvFileProcessor(CsvFileProcessor):
+    def load(self, file):
+        try:
+            return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
+        except pd.errors.EmptyDataError:
+            return pd.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
+        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
 
 
 class GzipFileProcessor(FileProcessor):
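Splitting CsvFileProcessor into per-framework subclasses also fixes a subtle bug visible in the removed lines: on an empty input the old polars branch returned pd.DataFrame(), whereas PolarsCsvFileProcessor.load now returns pl.DataFrame(). A rough usage sketch of the pandas subclass, assuming a pandas-only environment and substituting a plain path for the luigi file object gokart normally passes in:

    import pandas as pd

    from gokart.file_processor import PandasCsvFileProcessor

    proc = PandasCsvFileProcessor()  # defaults: sep=',', encoding='utf-8'
    proc.dump(pd.DataFrame({'a': [1, 2]}), 'example.csv')  # header written, index dropped
    df = proc.load('example.csv')  # an empty file comes back as pd.DataFrame()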
@@ -176,40 +187,42 @@ def format(self):
         return None
 
     def load(self, file):
-        try:
-            import pandas as pd
-
-            try:
-                return self.read_json(file)
-            except pd.errors.EmptyDataError:
-                return pd.DataFrame()
-        except ImportError:
-            import polars as pl
-
-            try:
-                return self.read_json(file)
-            except pl.exceptions.NoDataError:
-                return pl.DataFrame()
+        ...
 
     def dump(self, obj, file):
-        try:
-            import pandas as pd
-
-            assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), (
-                f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
-            )
-            if isinstance(obj, dict):
-                obj = pd.DataFrame.from_dict(obj)
-            obj.to_json(file)
-        except ImportError:
-            import polars as pl
-
-            assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), (
-                f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.'
-            )
-            if isinstance(obj, dict):
-                obj = pl.from_dict(obj)
-            obj.write_json(file)
+        ...
+
+
+class PolarsJsonFileProcessor(JsonFileProcessor):
+    def load(self, file):
+        try:
+            return self.read_json(file)
+        except pl.exceptions.NoDataError:
+            return pl.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), (
+            f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.'
+        )
+        if isinstance(obj, dict):
+            obj = pl.from_dict(obj)
+        obj.write_json(file)
+
+
+class PandasJsonFileProcessor(JsonFileProcessor):
+    def load(self, file):
+        try:
+            return self.read_json(file)
+        except pd.errors.EmptyDataError:
+            return pd.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), (
+            f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
+        )
+        if isinstance(obj, dict):
+            obj = pd.DataFrame.from_dict(obj)
+        obj.to_json(file)
 
 
 class XmlFileProcessor(FileProcessor):
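Both JSON subclasses keep accepting a plain dict and convert it before writing (pd.DataFrame.from_dict for pandas, pl.from_dict for polars). A hedged sketch of the pandas path, again using a path string in place of the usual file object and assuming the default constructor:

    from gokart.file_processor import PandasJsonFileProcessor

    proc = PandasJsonFileProcessor()
    # dict -> pd.DataFrame.from_dict -> DataFrame.to_json
    proc.dump({'a': [1, 2], 'b': [3, 4]}, 'example.json')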
@@ -243,118 +256,126 @@ class ParquetFileProcessor(FileProcessor):
     def __init__(self, engine='pyarrow', compression=None):
         self._engine = engine
         self._compression = compression
-        super(ParquetFileProcessor, self).__init__()
+        super().__init__()
 
     def format(self):
         return luigi.format.Nop
 
     def load(self, file):
-        try:
-            import pandas as pd
-
-            # FIXME(mamo3gr): enable streaming (chunked) read with S3.
-            # pandas.read_parquet accepts file-like object
-            # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
-            # which is needed for pandas to read a file in chunks.
-            if ObjectStorage.is_buffered_reader(file):
-                return pd.read_parquet(file.name)
-            else:
-                return pd.read_parquet(BytesIO(file.read()))
-        except ImportError:
-            import polars as pl
-
-            if ObjectStorage.is_buffered_reader(file):
-                return pl.read_parquet(file.name)
-            else:
-                return pl.read_parquet(BytesIO(file.read()))
+        ...
 
     def dump(self, obj, file):
-        try:
-            import pandas as pd
-
-            assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
-            # MEMO: to_parquet only supports a filepath as string (not a file handle)
-            obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
-        except ImportError:
-            import polars as pl
-
-            assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
-            use_pyarrow = self._engine == 'pyarrow'
-            compression = 'uncompressed' if self._compression is None else self._compression
-            obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression)
+        ...
+
+
+class PolarsParquetFileProcessor(ParquetFileProcessor):
+    def load(self, file):
+        if ObjectStorage.is_buffered_reader(file):
+            return pl.read_parquet(file.name)
+        else:
+            return pl.read_parquet(BytesIO(file.read()))
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
+        use_pyarrow = self._engine == 'pyarrow'
+        compression = 'uncompressed' if self._compression is None else self._compression
+        obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression)
+
+
+class PandasParquetFileProcessor(ParquetFileProcessor):
+    def load(self, file):
+        if ObjectStorage.is_buffered_reader(file):
+            return pd.read_parquet(file.name)
+        else:
+            return pd.read_parquet(BytesIO(file.read()))
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
+        # MEMO: to_parquet only supports a filepath as string (not a file handle)
+        obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
 
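On the polars side, dump translates the pandas-oriented constructor arguments: engine='pyarrow' becomes use_pyarrow=True, and compression=None maps to 'uncompressed' because write_parquet accepts no None. A rough local-file sketch of the pandas subclass (a real file handle stands in for the luigi target, since dump writes through file.name):

    import pandas as pd

    from gokart.file_processor import PandasParquetFileProcessor

    proc = PandasParquetFileProcessor()  # defaults: engine='pyarrow', compression=None
    with open('example.parquet', 'wb') as f:
        proc.dump(pd.DataFrame({'a': [1, 2]}), f)  # runs to_parquet(f.name, index=False, ...)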
 class FeatherFileProcessor(FileProcessor):
     def __init__(self, store_index_in_feather: bool):
-        super(FeatherFileProcessor, self).__init__()
+        super().__init__()
         self._store_index_in_feather = store_index_in_feather
         self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'
 
     def format(self):
         return luigi.format.Nop
 
     def load(self, file):
-        try:
-            import pandas as pd
-
-            # FIXME(mamo3gr): enable streaming (chunked) read with S3.
-            # pandas.read_feather accepts file-like object
-            # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
-            # which is needed for pandas to read a file in chunks.
-            if ObjectStorage.is_buffered_reader(file):
-                loaded_df = pd.read_feather(file.name)
-            else:
-                loaded_df = pd.read_feather(BytesIO(file.read()))
-
-            if self._store_index_in_feather:
-                if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
-                    index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
-                    index_column = index_columns[0]
-                    index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
-                    if index_name == 'None':
-                        index_name = None
-                    loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
-                    loaded_df = loaded_df.drop(columns={index_column})
-
-            return loaded_df
-        except ImportError:
-            import polars as pl
-
-            # Since polars' DataFrame doesn't have index, just load feather file
-            if ObjectStorage.is_buffered_reader(file):
-                loaded_df = pl.read_ipc(file.name)
-            else:
-                loaded_df = pl.read_ipc(BytesIO(file.read()))
-
-            return loaded_df
+        ...
 
     def dump(self, obj, file):
-        try:
-            import pandas as pd
-
-            assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
-            dump_obj = obj.copy()
-
-            if self._store_index_in_feather:
-                index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
-                assert index_column_name not in dump_obj.columns, (
-                    f'column name {index_column_name} already exists in dump_obj. \
-                        Consider not saving index by setting store_index_in_feather=False.'
-                )
-                assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'
-
-                dump_obj[index_column_name] = dump_obj.index
-                dump_obj = dump_obj.reset_index(drop=True)
-
-            # to_feather supports "binary" file-like object, but file variable is text
-            dump_obj.to_feather(file.name)
-        except ImportError:
-            import polars as pl
-
-            assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
-            dump_obj = obj.copy()
-            dump_obj.write_ipc(file.name)
+        ...
+
+
+class PolarsFeatherFileProcessor(FeatherFileProcessor):
+    def load(self, file):
+        # Since polars' DataFrame doesn't have index, just load feather file
+        if ObjectStorage.is_buffered_reader(file):
+            loaded_df = pl.read_ipc(file.name)
+        else:
+            loaded_df = pl.read_ipc(BytesIO(file.read()))
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
+        dump_obj = obj.copy()
+        dump_obj.write_ipc(file.name)
+
+
+class PandasFeatherFileProcessor(FeatherFileProcessor):
+    def load(self, file):
+        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
+        # pandas.read_feather accepts file-like object
+        # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
+        # which is needed for pandas to read a file in chunks.
+        if ObjectStorage.is_buffered_reader(file):
+            loaded_df = pd.read_feather(file.name)
+        else:
+            loaded_df = pd.read_feather(BytesIO(file.read()))
+
+        if self._store_index_in_feather:
+            if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
+                index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
+                index_column = index_columns[0]
+                index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
+                if index_name == 'None':
+                    index_name = None
+                loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
+                loaded_df = loaded_df.drop(columns={index_column})
+
+        return loaded_df
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
+        dump_obj = obj.copy()
+
+        if self._store_index_in_feather:
+            index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
+            assert index_column_name not in dump_obj.columns, (
+                f'column name {index_column_name} already exists in dump_obj. \
+                    Consider not saving index by setting store_index_in_feather=False.'
+            )
+            assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'
+
+            dump_obj[index_column_name] = dump_obj.index
+            dump_obj = dump_obj.reset_index(drop=True)
+
+        # to_feather supports "binary" file-like object, but file variable is text
+        dump_obj.to_feather(file.name)
 
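For Feather, only the pandas subclass round-trips the index: dump stores it in a column named '__feather_gokart_index__<index name>' before writing, and load finds that prefix, rebuilds the index, and drops the helper column; the polars subclass skips this entirely because polars DataFrames carry no index. The naming rule in isolation (values illustrative):

    INDEX_COLUMN_PREFIX = '__feather_gokart_index__'
    index_column = INDEX_COLUMN_PREFIX + 'id'  # column written for an index named 'id'
    index_name = index_column[len(INDEX_COLUMN_PREFIX):]  # 'id'; the literal 'None' maps back to None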
+if DATAFRAME_FRAMEWORK == 'polars':
+    CsvFileProcessor = PolarsCsvFileProcessor
+    JsonFileProcessor = PolarsJsonFileProcessor
+    ParquetFileProcessor = PolarsParquetFileProcessor
+    FeatherFileProcessor = PolarsFeatherFileProcessor
+else:
+    CsvFileProcessor = PandasCsvFileProcessor
+    JsonFileProcessor = PandasJsonFileProcessor
+    ParquetFileProcessor = PandasParquetFileProcessor
+    FeatherFileProcessor = PandasFeatherFileProcessor
 
 def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
     extension2processor = {

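Because the public names are rebound at the bottom of the module, existing call sites (including the extension2processor mapping in make_file_processor, truncated above) keep constructing the framework-appropriate subclass with no changes. A sketch of the observable effect, assuming an environment where polars is not installed:

    from gokart.file_processor import CsvFileProcessor, DATAFRAME_FRAMEWORK

    print(DATAFRAME_FRAMEWORK)  # 'pandas'
    print(type(CsvFileProcessor()).__name__)  # 'PandasCsvFileProcessor'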