Skip to content

Commit b004478

Browse files
rhshadrach, Kei, and undermyumbrella1
authored
BUG: groupby.agg with UDF changing pyarrow dtypes (#59601)
Co-authored-by: Kei <[email protected]> Co-authored-by: undermyumbrella1 <[email protected]> Co-authored-by: ellaella12 <[email protected]>
1 parent 0eaca9e commit b004478

File tree

4 files changed

+124
-12
lines changed

4 files changed

+124
-12
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ Groupby/resample/rolling
852852
- Bug in :meth:`DataFrame.ewm` and :meth:`Series.ewm` when passed ``times`` and aggregation functions other than mean (:issue:`51695`)
853853
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` were not keeping the index name when the index had :class:`ArrowDtype` timestamp dtype (:issue:`61222`)
854854
- Bug in :meth:`DataFrame.resample` changing index type to :class:`MultiIndex` when the dataframe is empty and using an upsample method (:issue:`55572`)
855+
- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` that was returning numpy dtype values when input values are pyarrow dtype values, instead of returning pyarrow dtype values. (:issue:`53030`)
855856
- Bug in :meth:`DataFrameGroupBy.agg` that raises ``AttributeError`` when there is dictionary input and duplicated columns, instead of returning a DataFrame with the aggregation of all duplicate columns. (:issue:`55041`)
856857
- Bug in :meth:`DataFrameGroupBy.agg` where applying a user-defined function to an empty DataFrame returned a Series instead of an empty DataFrame. (:issue:`61503`)
857858
- Bug in :meth:`DataFrameGroupBy.apply` and :meth:`SeriesGroupBy.apply` for empty data frame with ``group_keys=False`` still creating output index using group keys. (:issue:`60471`)

pandas/core/groupby/ops.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,15 @@
4444
ensure_platform_int,
4545
ensure_uint64,
4646
is_1d_only_ea_dtype,
47+
is_string_dtype,
4748
)
4849
from pandas.core.dtypes.missing import (
4950
isna,
5051
maybe_fill,
5152
)
5253

5354
from pandas.core.arrays import Categorical
55+
from pandas.core.arrays.arrow.array import ArrowExtensionArray
5456
from pandas.core.frame import DataFrame
5557
from pandas.core.groupby import grouper
5658
from pandas.core.indexes.api import (
@@ -963,18 +965,26 @@ def agg_series(
963965
-------
964966
np.ndarray or ExtensionArray
965967
"""
968+
result = self._aggregate_series_pure_python(obj, func)
969+
npvalues = lib.maybe_convert_objects(result, try_float=False)
970+
971+
if isinstance(obj._values, ArrowExtensionArray):
972+
# When obj.dtype is a string, any object can be cast. Only do so if the
973+
# UDF returned strings or NA values.
974+
if not is_string_dtype(obj.dtype) or lib.is_string_array(
975+
npvalues, skipna=True
976+
):
977+
out = maybe_cast_pointwise_result(
978+
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
979+
)
980+
else:
981+
out = npvalues
966982

967-
if not isinstance(obj._values, np.ndarray):
983+
elif not isinstance(obj._values, np.ndarray):
968984
# we can preserve a little bit more aggressively with EA dtype
969985
# because maybe_cast_pointwise_result will do a try/except
970986
# with _from_sequence. NB we are assuming here that _from_sequence
971987
# is sufficiently strict that it casts appropriately.
972-
preserve_dtype = True
973-
974-
result = self._aggregate_series_pure_python(obj, func)
975-
976-
npvalues = lib.maybe_convert_objects(result, try_float=False)
977-
if preserve_dtype:
978988
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
979989
else:
980990
out = npvalues

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
from pandas.errors import SpecificationError
13+
import pandas.util._test_decorators as td
1314

1415
from pandas.core.dtypes.common import is_integer_dtype
1516

@@ -23,6 +24,7 @@
2324
to_datetime,
2425
)
2526
import pandas._testing as tm
27+
from pandas.arrays import ArrowExtensionArray
2628
from pandas.core.groupby.grouper import Grouping
2729

2830

@@ -1809,6 +1811,102 @@ def test_groupby_aggregation_func_list_multi_index_duplicate_columns():
18091811
tm.assert_frame_equal(result, expected)
18101812

18111813

1814+
@td.skip_if_no("pyarrow")
1815+
@pytest.mark.parametrize(
1816+
"input_dtype, output_dtype",
1817+
[
1818+
# With NumPy arrays, the results from the UDF would be e.g. np.float32 scalars
1819+
# which we can therefore preserve. However with PyArrow arrays, the results are
1820+
# Python scalars so we have no information about size or uint vs int.
1821+
("float[pyarrow]", "double[pyarrow]"),
1822+
("int64[pyarrow]", "int64[pyarrow]"),
1823+
("uint64[pyarrow]", "int64[pyarrow]"),
1824+
("bool[pyarrow]", "bool[pyarrow]"),
1825+
],
1826+
)
1827+
def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype):
1828+
# GH#59601
1829+
# Test PyArrow dtype conversion back to PyArrow dtype
1830+
df = DataFrame(
1831+
{
1832+
"A": ["c1", "c2", "c3", "c1", "c2", "c3"],
1833+
"B": pd.array([100, 200, 255, 0, 199, 40392], dtype=input_dtype),
1834+
}
1835+
)
1836+
gb = df.groupby("A")
1837+
result = gb.agg(lambda x: x.min())
1838+
1839+
expected = DataFrame(
1840+
{"B": pd.array([0, 199, 255], dtype=output_dtype)},
1841+
index=Index(["c1", "c2", "c3"], name="A"),
1842+
)
1843+
tm.assert_frame_equal(result, expected)
1844+
1845+
1846+
@td.skip_if_no("pyarrow")
1847+
def test_agg_lambda_complex128_dtype_conversion():
1848+
# GH#59601
1849+
df = DataFrame(
1850+
{"A": ["c1", "c2", "c3"], "B": pd.array([100, 200, 255], "int64[pyarrow]")}
1851+
)
1852+
gb = df.groupby("A")
1853+
result = gb.agg(lambda x: complex(x.sum(), x.count()))
1854+
1855+
expected = DataFrame(
1856+
{
1857+
"B": pd.array(
1858+
[complex(100, 1), complex(200, 1), complex(255, 1)], dtype="complex128"
1859+
),
1860+
},
1861+
index=Index(["c1", "c2", "c3"], name="A"),
1862+
)
1863+
tm.assert_frame_equal(result, expected)
1864+
1865+
1866+
@td.skip_if_no("pyarrow")
1867+
def test_agg_lambda_numpy_uint64_to_pyarrow_dtype_conversion():
1868+
# GH#59601
1869+
df = DataFrame(
1870+
{
1871+
"A": ["c1", "c2", "c3"],
1872+
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
1873+
}
1874+
)
1875+
gb = df.groupby("A")
1876+
result = gb.agg(lambda x: np.uint64(x.sum()))
1877+
1878+
expected = DataFrame(
1879+
{
1880+
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
1881+
},
1882+
index=Index(["c1", "c2", "c3"], name="A"),
1883+
)
1884+
tm.assert_frame_equal(result, expected)
1885+
1886+
1887+
@td.skip_if_no("pyarrow")
1888+
def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion():
1889+
# GH#59601
1890+
import pyarrow as pa
1891+
1892+
df = DataFrame(
1893+
{
1894+
"A": ["c1", "c2", "c3"],
1895+
"B": pd.array([100, 200, 255], dtype="int64[pyarrow]"),
1896+
}
1897+
)
1898+
gb = df.groupby("A")
1899+
result = gb.agg(lambda x: {"number": 1})
1900+
1901+
arr = pa.array([{"number": 1}, {"number": 1}, {"number": 1}])
1902+
expected = DataFrame(
1903+
{"B": ArrowExtensionArray(arr)},
1904+
index=Index(["c1", "c2", "c3"], name="A"),
1905+
)
1906+
1907+
tm.assert_frame_equal(result, expected)
1908+
1909+
18121910
def test_groupby_aggregate_empty_builtin_sum():
18131911
df = DataFrame(columns=["Group", "Data"])
18141912
result = df.groupby(["Group"], as_index=False)["Data"].agg("sum")

pandas/tests/groupby/test_groupby.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2434,25 +2434,28 @@ def test_rolling_wrong_param_min_period():
24342434

24352435
def test_by_column_values_with_same_starting_value(any_string_dtype):
24362436
# GH29635
2437+
dtype = any_string_dtype
24372438
df = DataFrame(
24382439
{
24392440
"Name": ["Thomas", "Thomas", "Thomas John"],
24402441
"Credit": [1200, 1300, 900],
2441-
"Mood": Series(["sad", "happy", "happy"], dtype=any_string_dtype),
2442+
"Mood": Series(["sad", "happy", "happy"], dtype=dtype),
24422443
}
24432444
)
24442445
aggregate_details = {"Mood": Series.mode, "Credit": "sum"}
24452446

24462447
result = df.groupby(["Name"]).agg(aggregate_details)
2447-
expected_result = DataFrame(
2448+
expected = DataFrame(
24482449
{
24492450
"Mood": [["happy", "sad"], "happy"],
24502451
"Credit": [2500, 900],
24512452
"Name": ["Thomas", "Thomas John"],
2452-
}
2453+
},
24532454
).set_index("Name")
2454-
2455-
tm.assert_frame_equal(result, expected_result)
2455+
if getattr(dtype, "storage", None) == "pyarrow":
2456+
mood_values = pd.array(["happy", "sad"], dtype=dtype)
2457+
expected["Mood"] = [mood_values, "happy"]
2458+
tm.assert_frame_equal(result, expected)
24562459

24572460

24582461
def test_groupby_none_in_first_mi_level():

0 commit comments

Comments (0)