Skip to content

Commit 90b8961

Browse files
psmyth94, Patrick Smyth, lhoestq, and HuggingFaceDocBuilder
authored
Add polars compatibility (#6531)
* Add Polars support for data formatting and conversion * Update Polars availability check in config.py * added to_polars * changed the logic of importing polars if not already called * Remove to and from_polars from table.py in order to maintain pa.table logic-only * fix unused import * fixed code formatting with ruff * fix formatting issues with ruff * fix formatting issues using ruff * add tests for polars formatting * removed using InMemoryTable classmethod to convert polars to Table * added test for polars conversion * added missing ruff fixes * add polars in test dependencies * Fixed not executing default write method due to nested polars check. * Update src/datasets/arrow_dataset.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/arrow_dataset.py Co-authored-by: Quentin Lhoest <[email protected]> * Fix Polars DataFrame conversion bug * Fix DataFrame conversion in arrow_dataset.py * Fix variable name in arrow_dataset.py * Fix write_table to write_row in Dataset class * fix formatting with ruff * Update polars dependency to include timezone support * Remove polars in EXTRAS_REQUIRE * Replace deprecated method * perform cleanup after use * remove unused import * Add garbage collection to test_to_polars method * Remove unused import and unnecessary code in test_to_polars method * Add additional args for to_polars method * Fixed unclosed links to dataset file * ruff cleanup * even ruffier cleanup * changed hash to reflect new SHA for ref/convert/parquet --------- Co-authored-by: Patrick Smyth <[email protected]> Co-authored-by: Quentin Lhoest <[email protected]> Co-authored-by: Patrick Smyth <[email protected]>
1 parent 6fb6c83 commit 90b8961

File tree

10 files changed

+464
-4
lines changed

10 files changed

+464
-4
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@
185185
"transformers",
186186
"typing-extensions>=4.6.1", # due to conflict between apache-beam and pydantic
187187
"zstandard",
188+
"polars[timezone]>=0.20.0",
188189
]
189190

190191

src/datasets/arrow_dataset.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
if TYPE_CHECKING:
135135
import sqlite3
136136

137+
import polars as pl
137138
import pyspark
138139
import sqlalchemy
139140

@@ -868,6 +869,48 @@ def from_pandas(
868869
table = table.cast(features.arrow_schema)
869870
return cls(table, info=info, split=split)
870871

872+
@classmethod
def from_polars(
    cls,
    df: "pl.DataFrame",
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
) -> "Dataset":
    """
    Create a `Dataset` from a `polars.DataFrame`.

    The DataFrame is converted with `df.to_arrow()`, which collects the underlying
    arrow arrays in an Arrow Table. This operation is mostly zero copy.

    Data types that do copy:
        * CategoricalType

    Args:
        df (`polars.DataFrame`): DataFrame to convert to Arrow Table
        features (`Features`, optional): Dataset features. When omitted, `info.features` is used if `info` is given.
        info (`DatasetInfo`, optional): Dataset information, like description, citation, etc.
            Its `features` attribute is overwritten with the resolved features.
        split (`NamedSplit`, optional): Name of the dataset split.

    Raises:
        `ValueError`: If both `features` and `info` are given and `info.features` differs from `features`.

    Examples:
    ```py
    >>> ds = Dataset.from_polars(df)
    ```
    """
    if info is None:
        info = DatasetInfo()
    elif features is None:
        # fall back to the features carried by the provided info
        features = info.features
    elif info.features != features:
        raise ValueError(
            f"Features specified in `features` and `info.features` can't be different:\n{features}\n{info.features}"
        )
    info.features = features

    arrow_table = InMemoryTable(df.to_arrow())
    if features is not None:
        # a feature-aware cast is more expensive than a plain schema conversion,
        # but it is needed to support the str to Audio conversion for instance
        arrow_table = arrow_table.cast(features.arrow_schema)
    return cls(arrow_table, info=info, split=split)
913+
871914
@classmethod
872915
def from_dict(
873916
cls,
@@ -3319,6 +3362,10 @@ def validate_function_output(processed_inputs, indices):
33193362
)
33203363
elif isinstance(indices, list) and isinstance(processed_inputs, Mapping):
33213364
allowed_batch_return_types = (list, np.ndarray, pd.Series)
3365+
if config.POLARS_AVAILABLE and "polars" in sys.modules:
3366+
import polars as pl
3367+
3368+
allowed_batch_return_types += (pl.Series, pl.DataFrame)
33223369
if config.TF_AVAILABLE and "tensorflow" in sys.modules:
33233370
import tensorflow as tf
33243371

@@ -3438,6 +3485,10 @@ def init_buffer_and_writer():
34383485
# If `update_data` is True after processing the first example/batch, initalize these resources with `init_buffer_and_writer`
34393486
buf_writer, writer, tmp_file = None, None, None
34403487

3488+
# Check if Polars is available and import it if so
3489+
if config.POLARS_AVAILABLE and "polars" in sys.modules:
3490+
import polars as pl
3491+
34413492
# Optionally initialize the writer as a context manager
34423493
with contextlib.ExitStack() as stack:
34433494
try:
@@ -3464,6 +3515,12 @@ def init_buffer_and_writer():
34643515
writer.write_row(example)
34653516
elif isinstance(example, pd.DataFrame):
34663517
writer.write_row(pa.Table.from_pandas(example))
3518+
elif (
3519+
config.POLARS_AVAILABLE
3520+
and "polars" in sys.modules
3521+
and isinstance(example, pl.DataFrame)
3522+
):
3523+
writer.write_row(example.to_arrow())
34673524
else:
34683525
writer.write(example)
34693526
num_examples_progress_update += 1
@@ -3497,6 +3554,10 @@ def init_buffer_and_writer():
34973554
writer.write_table(batch)
34983555
elif isinstance(batch, pd.DataFrame):
34993556
writer.write_table(pa.Table.from_pandas(batch))
3557+
elif (
3558+
config.POLARS_AVAILABLE and "polars" in sys.modules and isinstance(batch, pl.DataFrame)
3559+
):
3560+
writer.write_table(batch.to_arrow())
35003561
else:
35013562
writer.write_batch(batch)
35023563
num_examples_progress_update += num_examples_in_batch
@@ -4949,6 +5010,66 @@ def to_pandas(
49495010
for offset in range(0, len(self), batch_size)
49505011
)
49515012

5013+
def to_polars(
    self,
    batch_size: Optional[int] = None,
    batched: bool = False,
    schema_overrides: Optional[dict] = None,
    rechunk: bool = True,
) -> Union["pl.DataFrame", Iterator["pl.DataFrame"]]:
    """Returns the dataset as a `polars.DataFrame`. Can also return a generator for large datasets.

    Args:
        batch_size (`int`, *optional*):
            The size (number of rows) of the batches if `batched` is `True`.
            Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
        batched (`bool`):
            Set to `True` to return a generator that yields the dataset as batches
            of `batch_size` rows. Defaults to `False` (returns the whole datasets once).
        schema_overrides (`dict`, *optional*):
            Support type specification or override of one or more columns; note that
            any dtypes inferred from the schema param will be overridden.
            Forwarded as-is to `polars.from_arrow`.
        rechunk (`bool`):
            Make sure that all data is in contiguous memory. Defaults to `True`.
            Forwarded as-is to `polars.from_arrow`.

    Returns:
        `polars.DataFrame` or `Iterator[polars.DataFrame]`

    Raises:
        `ValueError`: If Polars is not installed.

    Example:

    ```py
    >>> ds.to_polars()
    ```
    """
    if not config.POLARS_AVAILABLE:
        raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")

    # imported lazily so that Polars stays an optional dependency
    import polars as pl

    def _rows_to_polars(key: slice) -> "pl.DataFrame":
        # query the rows selected by `key` (through the indices mapping, if any) and convert them
        return pl.from_arrow(
            query_table(table=self._data, key=key, indices=self._indices),
            schema_overrides=schema_overrides,
            rechunk=rechunk,
        )

    if not batched:
        return _rows_to_polars(slice(0, len(self)))

    batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
    return (_rows_to_polars(slice(offset, offset + batch_size)) for offset in range(0, len(self), batch_size))
5072+
49525073
def to_parquet(
49535074
self,
49545075
path_or_buf: Union[PathLike, BinaryIO],

src/datasets/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@
6161
else:
6262
logger.info("Disabling PyTorch because USE_TF is set")
6363

64+
# Polars is an optional dependency: detect it without importing it.
POLARS_VERSION = "N/A"
POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None

if POLARS_AVAILABLE:
    try:
        POLARS_VERSION = version.parse(importlib.metadata.version("polars"))
    except importlib.metadata.PackageNotFoundError:
        # the module is importable but has no distribution metadata: keep the "N/A" version
        pass
    else:
        logger.info(f"Polars version {POLARS_VERSION} available.")
73+
6474
TF_VERSION = "N/A"
6575
TF_AVAILABLE = False
6676

src/datasets/formatting/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ def _register_unavailable_formatter(
8080
_register_formatter(PandasFormatter, "pandas", aliases=["pd"])
8181
_register_formatter(CustomFormatter, "custom")
8282

83+
# Polars is optional: register the real formatter when it is installed,
# otherwise register the error to report for the "polars" / "pl" format types.
if config.POLARS_AVAILABLE:
    # imported only in this branch so that the formatter module is not loaded without Polars
    from .polars_formatter import PolarsFormatter

    _register_formatter(PolarsFormatter, "polars", aliases=["pl"])
else:
    _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    _register_unavailable_formatter(_polars_error, "polars", aliases=["pl"])
90+
8391
if config.TORCH_AVAILABLE:
8492
from .torch_formatter import TorchFormatter
8593

src/datasets/formatting/polars_formatter.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2020 The HuggingFace Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
from collections.abc import Mapping
17+
from functools import partial
18+
from typing import TYPE_CHECKING, Optional
19+
20+
import pyarrow as pa
21+
22+
from .. import config
23+
from ..features import Features
24+
from ..features.features import decode_nested_example
25+
from ..utils.py_utils import no_op_if_value_is_null
26+
from .formatting import BaseArrowExtractor, TensorFormatter
27+
28+
29+
if TYPE_CHECKING:
30+
import polars as pl
31+
32+
33+
class PolarsArrowExtractor(BaseArrowExtractor["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Extract rows, columns and batches from a `pyarrow.Table` as Polars objects."""

    @staticmethod
    def _polars():
        """Return the `polars` module, importing it lazily.

        Raises:
            ValueError: If Polars is not installed.
        """
        if not config.POLARS_AVAILABLE:
            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
        # `import` returns the module cached in `sys.modules` when Polars was already loaded,
        # so no explicit `sys.modules` lookup is needed
        import polars

        return polars

    def extract_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the first row of `pa_table` as a one-row `polars.DataFrame`."""
        return self._polars().from_arrow(pa_table.slice(length=1))

    def extract_column(self, pa_table: pa.Table) -> "pl.Series":
        """Return the first column of `pa_table` as a `polars.Series`."""
        return self._polars().from_arrow(pa_table.select([0]))[pa_table.column_names[0]]

    def extract_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the whole `pa_table` as a `polars.DataFrame`."""
        return self._polars().from_arrow(pa_table)
66+
67+
68+
class PolarsFeaturesDecoder:
    """Apply the decoding defined by `Features` to Polars rows, columns and batches."""

    def __init__(self, features: Optional[Features]):
        """Store the features used for decoding.

        Args:
            features (`Features`, optional): Features of the data to decode. No decoding happens when it is `None`.
        """
        self.features = features
        import polars as pl  # noqa: F401 - import pl at initialization

    def decode_row(self, row: "pl.DataFrame") -> "pl.DataFrame":
        """Decode every column of `row` whose feature requires decoding.

        Args:
            row (`polars.DataFrame`): Frame holding the row (or rows) to decode.

        Returns:
            `polars.DataFrame`: A frame where the columns that require decoding hold decoded values.
            The input is returned as-is when there is nothing to decode.
        """
        decode = (
            {
                column_name: no_op_if_value_is_null(partial(decode_nested_example, feature))
                for column_name, feature in self.features.items()
                if self.features._column_requires_decoding[column_name]
            }
            if self.features
            else {}
        )
        # only columns actually present in the frame can be decoded
        decoded_columns = [
            row[column_name].map_elements(decoder)
            for column_name, decoder in decode.items()
            if column_name in row.columns
        ]
        if decoded_columns:
            # `DataFrame.map_rows` expects a callable (not a mapping of per-column callables) and
            # Polars frames do not support `df[[...]] = frame` assignment, so each column is
            # decoded on its own and swapped in by name with `with_columns`
            row = row.with_columns(decoded_columns)
        return row

    def decode_column(self, column: "pl.Series", column_name: str) -> "pl.Series":
        """Decode `column` if the feature named `column_name` requires decoding.

        Args:
            column (`polars.Series`): Values of the column.
            column_name (`str`): Name used to look up the column's feature.

        Returns:
            `polars.Series`: The decoded column, or `column` unchanged when no decoding is required.
        """
        decode = (
            no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
            if self.features and column_name in self.features and self.features._column_requires_decoding[column_name]
            else None
        )
        if decode:
            column = column.map_elements(decode)
        return column

    def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
        """Decode a batch; batches and rows are both frames and share the same column-wise decoding."""
        return self.decode_row(batch)
99+
100+
101+
class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
    """Formatter that returns Arrow data as Polars objects, decoding features where required."""

    def __init__(self, features=None, **np_array_kwargs):
        """Set up the extractor class and the features decoder.

        Args:
            features: Features forwarded to the base formatter and to the decoder.
            **np_array_kwargs: Extra keyword arguments, stored on the instance as `np_array_kwargs`.
        """
        super().__init__(features=features)
        self.np_array_kwargs = np_array_kwargs
        # the extractor class itself is stored; a new extractor is instantiated for each call
        self.polars_arrow_extractor = PolarsArrowExtractor
        self.polars_features_decoder = PolarsFeaturesDecoder(features)
        import polars as pl  # noqa: F401 - import pl at initialization

    def format_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract a row from `pa_table` as a `polars.DataFrame` and decode it."""
        extracted = self.polars_arrow_extractor().extract_row(pa_table)
        return self.polars_features_decoder.decode_row(extracted)

    def format_column(self, pa_table: pa.Table) -> "pl.Series":
        """Extract the first column of `pa_table` as a `polars.Series` and decode it."""
        extracted = self.polars_arrow_extractor().extract_column(pa_table)
        return self.polars_features_decoder.decode_column(extracted, pa_table.column_names[0])

    def format_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract `pa_table` as a `polars.DataFrame` batch and decode it."""
        extracted = self.polars_arrow_extractor().extract_batch(pa_table)
        return self.polars_features_decoder.decode_batch(extracted)

0 commit comments

Comments
 (0)