Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 92782f5

Browse files
authored
Fix stata saving (#624)
close #618
1 parent c08038a commit 92782f5

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

mlem/contrib/pandas.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ def has_index(df: pd.DataFrame):
305305
return not isinstance(df.index, pd.RangeIndex)
306306

307307

308+
def has_named_index(df: pd.DataFrame) -> bool:
    """Tell whether every level of the frame's index has a name.

    Args:
        df: data frame whose index is inspected.

    Returns:
        ``True`` when ``df.index.name`` is truthy or every entry of
        ``df.index.names`` is truthy; ``False`` otherwise.
    """
    # A bare `or` would hand the index name itself (e.g. a string) back to the
    # caller; coerce so the result is always a real bool, as documented.
    return bool(df.index.name or all(df.index.names))
311+
312+
308313
def _reset_index(df: pd.DataFrame):
309314
"""Transforms indexes to columns"""
310315
index_name = df.index.name or "" # save it for future renaming
@@ -459,8 +464,19 @@ def write(
459464
write_kwargs.update(self.write_args)
460465
write_kwargs.update(kwargs)
461466

467+
# sometimes index may be consumed by model or used at feature engineering step,
468+
# so we keep it instead of dropping if it's non-trivial
462469
if has_index(df):
463-
df = reset_index(df)
470+
if PANDAS_FORMATS[
471+
"stata"
472+
].write_func == self.write_func and not has_named_index(df):
473+
logging.info(
474+
"Stata format doesn't allow saving columns with empty names, so you must name the index. "
475+
"Use `df.index.name = 'index'` to name it or df.reset_index(drop=True) to drop it instead."
476+
)
477+
df = df.reset_index(drop=True)
478+
else:
479+
df = reset_index(df)
464480

465481
with storage.open(path) as (f, art):
466482
if self.string_buffer:

tests/contrib/test_pandas.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from fsspec.implementations.local import LocalFileSystem
1212
from pydantic import parse_obj_as
13+
from pytest_lazyfixture import lazy_fixture
1314
from sklearn.datasets import load_iris
1415
from sklearn.model_selection import train_test_split
1516

@@ -27,6 +28,7 @@
2728
PandasWriter,
2829
SeriesType,
2930
get_pandas_batch_formats,
31+
has_index,
3032
pd_type_from_string,
3133
python_type_from_pd_string_repr,
3234
python_type_from_pd_type,
@@ -510,6 +512,20 @@ def test_import_data_csv(tmpdir, write_csv, file_ext, type_, data):
510512
_check_data(meta, target_path)
511513

512514

515+
def test_import_data_stata(tmpdir, data):
    """Write the fixture frame to a Stata file, import it, and compare."""
    stata_path = str(tmpdir / "mydata.stata")
    data.to_stata(stata_path, write_index=False)

    imported = import_object(
        stata_path,
        target=stata_path,
        type_="pandas[stata]",
        copy_data=True,
    )

    # TODO: int32 converts to int64 for some reason for stata
    # NOTE(review): the cast below suggests the opposite direction -- the
    # frame read back is int32, presumably because pandas downcasts int64
    # when writing Stata files; confirm against the pandas `to_stata` docs.
    expected = data.astype("int32")
    pandas_assert(expected, imported.get_value())
527+
528+
513529
@long
514530
def test_import_data_csv_remote(s3_tmp_path, s3_storage_fs, write_csv):
515531
project_path = s3_tmp_path("test_csv_import")
@@ -619,6 +635,28 @@ def f(x):
619635
assert set(get_object_requirements(sig).modules) == {"pandas"}
620636

621637

638+
@pytest.mark.parametrize("df", [lazy_fixture("data"), lazy_fixture("data2")])
def test_does_not_have_index(df):
    """The `data` and `data2` fixture frames must not be reported as indexed."""
    assert not has_index(df)
647+
648+
649+
@pytest.mark.parametrize("df", [PD_DATA_FRAME_INDEX, PD_DATA_FRAME_MULTIINDEX])
def test_has_index(df):
    """The index and multi-index sample frames must be reported as indexed."""
    assert has_index(df)
658+
659+
622660
# Copyright 2019 Zyfra
623661
# Copyright 2021 Iterative
624662
#

0 commit comments

Comments
 (0)