
Introduce Polars for dumping and loading data #457


Status: Open. hirosassa wants to merge 16 commits into master.
Changes from 6 commits
14 changes: 4 additions & 10 deletions .github/workflows/test.yml
@@ -11,17 +11,9 @@ jobs:
     strategy:
       max-parallel: 7
       matrix:
+        platform: ["ubuntu-latest"]
+        tox-env: ["py39", "py310", "py311", "py312", "py313"]
         include:
-          - platform: ubuntu-latest
-            tox-env: "py39"
-          - platform: ubuntu-latest
-            tox-env: "py310"
-          - platform: ubuntu-latest
-            tox-env: "py311"
-          - platform: ubuntu-latest
-            tox-env: "py312"
-          - platform: ubuntu-latest
-            tox-env: "py313"
           # test only on latest python for macos
           - platform: macos-13
             tox-env: "py313"
@@ -38,3 +30,5 @@ jobs:
           uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv
       - name: Test with tox
         run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}
+      - name: Test with tox for polars extra
+        run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}-polars
Comment on lines +32 to +33

hirosassa (Collaborator, Author):
add the test run with polars extra
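Not from the thread: the ${{ matrix.tox-env }}-polars environment names presumably select a tox factor that installs the new polars extra before running the same suite; the matching tox.ini change is not visible in this diff.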

115 changes: 109 additions & 6 deletions gokart/file_processor.py
@@ -11,8 +11,6 @@
 import luigi.contrib.s3
 import luigi.format
 import numpy as np
-import pandas as pd
-import pandas.errors
 from luigi.format import TextFormat

 from gokart.object_storage import ObjectStorage
@@ -21,6 +19,16 @@
 logger = getLogger(__name__)


+try:
+    import polars as pl
+
+    DATAFRAME_FRAMEWORK = 'polars'
+except ImportError:
+    import pandas as pd
+
+    DATAFRAME_FRAMEWORK = 'pandas'
+
+
 class FileProcessor:
     @abstractmethod
     def format(self):

Contributor (comment on the line importing polars):

I prefer to raise an exception instead of ignoring the ValueError. Users will never see ValueError('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED is not set. Use pandas as dataframe framework.') since it is ignored on L31.

import os

DATAFRAME_FRAMEWORK = os.getenv('GOKART_DATAFRAME_FRAMEWORK', 'pandas')

if DATAFRAME_FRAMEWORK == 'polars':
    try:
        import polars
    except ImportError:
        raise ValueError('please install polars to use polars as a framework of dataframe for gokart')

hirosassa (Collaborator, Author) replied on Mar 29, 2025:

Thanks for your suggestion. Fixed in 79839e1.

@@ -131,6 +139,24 @@ def __init__(self, sep=',', encoding: str = 'utf-8'):
     def format(self):
         return TextFormat(encoding=self._encoding)

+    def load(self, file): ...
+
+    def dump(self, obj, file): ...
+
+
+class PolarsCsvFileProcessor(CsvFileProcessor):
+    def load(self, file):
+        try:
+            return pl.read_csv(file, separator=self._sep, encoding=self._encoding)
+        except pl.exceptions.NoDataError:
+            return pl.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.'
+        obj.write_csv(file, separator=self._sep, include_header=True)
+
+
+class PandasCsvFileProcessor(CsvFileProcessor):
     def load(self, file):
         try:
             return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
@@ -164,6 +190,34 @@ def __init__(self, orient: str | None = None):
     def format(self):
         return luigi.format.Nop

+    def load(self, file): ...
+
+    def dump(self, obj, file): ...
+
+
+class PolarsJsonFileProcessor(JsonFileProcessor):
+    def load(self, file):
+        try:
+            if self._orient == 'records':
+                return pl.read_ndjson(file)
+            return pl.read_json(file)
+        except pl.exceptions.ComputeError:
+            return pl.DataFrame()
+
+    def dump(self, obj, file):
+        assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), (
+            f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.'
+        )
+        if isinstance(obj, dict):
+            obj = pl.from_dict(obj)
+
+        if self._orient == 'records':
+            obj.write_ndjson(file)
+        else:
+            obj.write_json(file)
+
+
+class PandasJsonFileProcessor(JsonFileProcessor):
     def load(self, file):
         try:
             return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False)
@@ -215,11 +269,27 @@ def __init__(self, engine='pyarrow', compression=None):
     def format(self):
         return luigi.format.Nop

+    def load(self, file): ...
+
+    def dump(self, obj, file): ...
+
+
+class PolarsParquetFileProcessor(ParquetFileProcessor):
+    def load(self, file):
+        if ObjectStorage.is_buffered_reader(file):
+            return pl.read_parquet(file.name)
+        else:
+            return pl.read_parquet(BytesIO(file.read()))
+
+    def dump(self, obj, file):
+        assert isinstance(obj, pl.DataFrame), f'requires pl.DataFrame, but {type(obj)} is passed.'
+        use_pyarrow = self._engine == 'pyarrow'
+        compression = 'uncompressed' if self._compression is None else self._compression
+        obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression)
+
+
+class PandasParquetFileProcessor(ParquetFileProcessor):
     def load(self, file):
         # FIXME(mamo3gr): enable streaming (chunked) read with S3.
         # pandas.read_parquet accepts file-like object
         # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
         # which is needed for pandas to read a file in chunks.
         if ObjectStorage.is_buffered_reader(file):
             return pd.read_parquet(file.name)
         else:
@@ -240,6 +310,27 @@ def __init__(self, store_index_in_feather: bool):
     def format(self):
         return luigi.format.Nop

+    def load(self, file): ...
+
+    def dump(self, obj, file): ...
+
+
+class PolarsFeatherFileProcessor(FeatherFileProcessor):
+    def load(self, file):
+        # Since polars' DataFrame doesn't have an index, just load the feather file.
+        # TODO: fix ignoring the store_index_in_feather variable.
+        # Currently PolarsFeatherFileProcessor ignores store_index_in_feather to avoid
+        # a breaking change to FeatherFileProcessor's default behavior.
+        if ObjectStorage.is_buffered_reader(file):
+            return pl.read_ipc(file.name)
+        return pl.read_ipc(BytesIO(file.read()))
+
+    def dump(self, obj, file):
+        assert isinstance(obj, pl.DataFrame), f'requires pl.DataFrame, but {type(obj)} is passed.'
+        obj.write_ipc(file.name)
+
+
+class PandasFeatherFileProcessor(FeatherFileProcessor):
     def load(self, file):
         # FIXME(mamo3gr): enable streaming (chunked) read with S3.
         # pandas.read_feather accepts file-like object
@@ -281,6 +372,18 @@ def dump(self, obj, file):
         dump_obj.to_feather(file.name)


+if DATAFRAME_FRAMEWORK == 'polars':
+    CsvFileProcessor = PolarsCsvFileProcessor  # type: ignore
+    JsonFileProcessor = PolarsJsonFileProcessor  # type: ignore
+    ParquetFileProcessor = PolarsParquetFileProcessor  # type: ignore
+    FeatherFileProcessor = PolarsFeatherFileProcessor  # type: ignore
+else:
+    CsvFileProcessor = PandasCsvFileProcessor  # type: ignore
+    JsonFileProcessor = PandasJsonFileProcessor  # type: ignore
+    ParquetFileProcessor = PandasParquetFileProcessor  # type: ignore
+    FeatherFileProcessor = PandasFeatherFileProcessor  # type: ignore
+
+
 def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
     extension2processor = {
         '.txt': TextFileProcessor(),
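Not part of the diff: a minimal round-trip sketch of the alias selection above. It assumes this branch is installed with polars available, so that CsvFileProcessor resolves to PolarsCsvFileProcessor, and it only touches names defined in this file.

# Sketch only: requires polars to be importable; otherwise the alias
# falls back to PandasCsvFileProcessor.
import io

import polars as pl

from gokart.file_processor import DATAFRAME_FRAMEWORK, CsvFileProcessor, PolarsCsvFileProcessor

assert DATAFRAME_FRAMEWORK == 'polars'

processor = CsvFileProcessor()  # resolves to PolarsCsvFileProcessor at import time
assert isinstance(processor, PolarsCsvFileProcessor)

df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

buf = io.BytesIO()
processor.dump(df, buf)  # PolarsCsvFileProcessor.dump -> DataFrame.write_csv
buf.seek(0)
loaded = processor.load(buf)  # PolarsCsvFileProcessor.load -> pl.read_csv

assert df.equals(loaded)

Pandas-only installations take the else branch of the selection block, so existing pipelines keep their current behavior.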
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -45,6 +45,9 @@ Homepage = "https://github.com/m3dev/gokart"
 Repository = "https://github.com/m3dev/gokart"
 Documentation = "https://gokart.readthedocs.io/en/latest/"

+[project.optional-dependencies]
+polars = ["polars-lts-cpu"]
+
 [dependency-groups]
 test = [
     "fakeredis",
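Usage note (not in the diff): with this table, pip install gokart[polars] installs the extra, which pulls in polars-lts-cpu, the polars build for older CPUs without newer SIMD instructions, rather than the default polars wheel.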
156 changes: 9 additions & 147 deletions test/test_file_processor.py
@@ -6,126 +6,26 @@
 from typing import Callable

 import boto3
-import pandas as pd
 import pytest
 from luigi import LocalTarget
 from moto import mock_aws

 from gokart.file_processor import (
+    DATAFRAME_FRAMEWORK,
     CsvFileProcessor,
-    FeatherFileProcessor,
     GzipFileProcessor,
-    JsonFileProcessor,
     NpzFileProcessor,
+    PandasCsvFileProcessor,
-    ParquetFileProcessor,
     PickleFileProcessor,
+    PolarsCsvFileProcessor,
     TextFileProcessor,
     make_file_processor,
 )
 from gokart.object_storage import ObjectStorage


-class TestCsvFileProcessor(unittest.TestCase):
-    def test_dump_csv_with_utf8(self):
-        df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
-        processor = CsvFileProcessor()
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_path = f'{temp_dir}/temp.csv'
-
-            local_target = LocalTarget(path=temp_path, format=processor.format())
-            with local_target.open('w') as f:
-                processor.dump(df, f)
-
-            # read with utf-8 to check if the file is dumped with utf8
-            loaded_df = pd.read_csv(temp_path, encoding='utf-8')
-            pd.testing.assert_frame_equal(df, loaded_df)
-
-    def test_dump_csv_with_cp932(self):
-        df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
-        processor = CsvFileProcessor(encoding='cp932')
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_path = f'{temp_dir}/temp.csv'
-
-            local_target = LocalTarget(path=temp_path, format=processor.format())
-            with local_target.open('w') as f:
-                processor.dump(df, f)
-
-            # read with cp932 to check if the file is dumped with cp932
-            loaded_df = pd.read_csv(temp_path, encoding='cp932')
-            pd.testing.assert_frame_equal(df, loaded_df)
-
-    def test_load_csv_with_utf8(self):
-        df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
-        processor = CsvFileProcessor()
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_path = f'{temp_dir}/temp.csv'
-            df.to_csv(temp_path, encoding='utf-8', index=False)
-
-            local_target = LocalTarget(path=temp_path, format=processor.format())
-            with local_target.open('r') as f:
-                # read with utf-8 to check if the file is dumped with utf8
-                loaded_df = processor.load(f)
-                pd.testing.assert_frame_equal(df, loaded_df)
-
-    def test_load_csv_with_cp932(self):
-        df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
-        processor = CsvFileProcessor(encoding='cp932')
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_path = f'{temp_dir}/temp.csv'
-            df.to_csv(temp_path, encoding='cp932', index=False)
-
-            local_target = LocalTarget(path=temp_path, format=processor.format())
-            with local_target.open('r') as f:
-                # read with cp932 to check if the file is dumped with cp932
-                loaded_df = processor.load(f)
-                pd.testing.assert_frame_equal(df, loaded_df)
-
-
-class TestJsonFileProcessor:
-    @pytest.mark.parametrize(
-        'orient,input_data,expected_json',
-        [
-            pytest.param(
-                None,
-                pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
-                '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}',
-                id='With Default Orient for DataFrame',
-            ),
-            pytest.param(
-                'records',
-                pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
-                '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n',
-                id='With Records Orient for DataFrame',
-            ),
-            pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'),
-            pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'),
-            pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'),
-            pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'),
-        ],
-    )
-    def test_dump_and_load_json(self, orient, input_data, expected_json):
-        processor = JsonFileProcessor(orient=orient)
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_path = f'{temp_dir}/temp.json'
-            local_target = LocalTarget(path=temp_path, format=processor.format())
-            with local_target.open('w') as f:
-                processor.dump(input_data, f)
-            with local_target.open('r') as f:
-                loaded_df = processor.load(f)
-                f.seek(0)
-                loaded_json = f.read().decode('utf-8')
-
-        assert loaded_json == expected_json
-
-        df_input = pd.DataFrame(input_data)
-        pd.testing.assert_frame_equal(df_input, loaded_df)


 class TestPickleFileProcessor(unittest.TestCase):
     def test_dump_and_load_normal_obj(self):
         var = 'abc'
Expand Down Expand Up @@ -189,50 +89,12 @@ def test_dump_and_load_with_readables3file(self):
self.assertEqual(loaded, var)


class TestFeatherFileProcessor(unittest.TestCase):
def test_feather_should_return_same_dataframe(self):
df = pd.DataFrame({'a': [1]})
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

with local_target.open('r') as f:
loaded_df = processor.load(f)

pd.testing.assert_frame_equal(df, loaded_df)

def test_feather_should_save_index_name(self):
df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name'))
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

with local_target.open('r') as f:
loaded_df = processor.load(f)

pd.testing.assert_frame_equal(df, loaded_df)

def test_feather_should_raise_error_index_name_is_None(self):
df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None'))
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
with self.assertRaises(AssertionError):
processor.dump(df, f)
class TestFileProcessorClassSelection(unittest.TestCase):
def test_processor_selection(self):
if DATAFRAME_FRAMEWORK == 'polars':
self.assertTrue(issubclass(CsvFileProcessor, PolarsCsvFileProcessor))
else:
self.assertTrue(issubclass(CsvFileProcessor, PandasCsvFileProcessor))


class TestMakeFileProcessor(unittest.TestCase):
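Not part of the diff: the removed pandas CSV tests have no direct polars counterpart in this commit. A minimal sketch of what one might look like, mirroring the deleted test structure and using only classes added in this PR (assumes polars is installed):

import tempfile
import unittest

import polars as pl
from luigi import LocalTarget
from polars.testing import assert_frame_equal

from gokart.file_processor import PolarsCsvFileProcessor


class TestPolarsCsvFileProcessor(unittest.TestCase):
    def test_dump_and_load_roundtrip(self):
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = PolarsCsvFileProcessor()

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.csv'

            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)

            with local_target.open('r') as f:
                loaded_df = processor.load(f)

            # raises AssertionError on any schema or value mismatch
            assert_frame_equal(df, loaded_df)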