Skip to content

Commit 0f17f86

Browse files
authored
TYP: GroupBy and native types (#1556)
* native types and groupby * #1556 (comment) * more reverts * furthre eliminate GroupBy[Any, Any] in tests https://github.com/pandas-dev/pandas-stubs/pull/1556/changes#r2624041237
1 parent cb6208b commit 0f17f86

File tree

9 files changed

+52
-35
lines changed

9 files changed

+52
-35
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
463463
self,
464464
orient: str = ...,
465465
*,
466-
into: type[defaultdict],
466+
into: type[defaultdict[Any, Any]],
467467
index: Literal[True] = True,
468468
) -> Never: ...
469469
@overload
@@ -479,7 +479,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
479479
self,
480480
orient: Literal["records"],
481481
*,
482-
into: type[dict] = ...,
482+
into: type[dict[Any, Any]] = ...,
483483
index: Literal[True] = True,
484484
) -> list[dict[Hashable, Any]]: ...
485485
@overload
@@ -495,23 +495,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
495495
self,
496496
orient: Literal["index"],
497497
*,
498-
into: OrderedDict | type[OrderedDict],
498+
into: OrderedDict[Any, Any] | type[OrderedDict[Any, Any]],
499499
index: Literal[True] = True,
500500
) -> OrderedDict[Hashable, dict[Hashable, Any]]: ...
501501
@overload
502502
def to_dict(
503503
self,
504504
orient: Literal["index"],
505505
*,
506-
into: type[MutableMapping],
506+
into: type[MutableMapping[Any, Any]],
507507
index: Literal[True] = True,
508508
) -> MutableMapping[Hashable, dict[Hashable, Any]]: ...
509509
@overload
510510
def to_dict(
511511
self,
512512
orient: Literal["index"],
513513
*,
514-
into: type[dict] = ...,
514+
into: type[dict[Any, Any]] = ...,
515515
index: Literal[True] = True,
516516
) -> dict[Hashable, dict[Hashable, Any]]: ...
517517
@overload
@@ -527,23 +527,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
527527
self,
528528
orient: Literal["dict", "list", "series"] = ...,
529529
*,
530-
into: type[dict] = ...,
530+
into: type[dict[Any, Any]] = ...,
531531
index: Literal[True] = True,
532532
) -> dict[Hashable, Any]: ...
533533
@overload
534534
def to_dict(
535535
self,
536536
orient: Literal["split", "tight"],
537537
*,
538-
into: MutableMapping[Any, Any] | type[MutableMapping],
538+
into: MutableMapping[Any, Any] | type[MutableMapping[Any, Any]],
539539
index: bool = ...,
540540
) -> MutableMapping[str, list[Any]]: ...
541541
@overload
542542
def to_dict(
543543
self,
544544
orient: Literal["split", "tight"],
545545
*,
546-
into: type[dict] = ...,
546+
into: type[dict[Any, Any]] = ...,
547547
index: bool = ...,
548548
) -> dict[str, list[Any]]: ...
549549
@classmethod

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
338338
random_state: RandomState | None = ...,
339339
) -> NDFrameT: ...
340340

341-
_GroupByT = TypeVar("_GroupByT", bound=GroupBy)
341+
_GroupByT = TypeVar("_GroupByT", bound=GroupBy[Any])
342342

343343
# GroupByPlot does not really inherit from PlotAccessor but it delegates
344344
# to it using __call__ and __getattr__. We lie here to avoid repeating the
@@ -383,15 +383,15 @@ class BaseGroupBy(SelectionMixin[NDFrameT], GroupByIndexingMixin):
383383
@final
384384
def __iter__(self) -> Iterator[tuple[Hashable, NDFrameT]]: ...
385385
@overload
386-
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
386+
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy[Any, Any]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
387387
@overload
388388
def __getitem__(
389389
self: BaseGroupBy[DataFrame], key: Iterable[Hashable]
390-
) -> generic.DataFrameGroupBy: ...
390+
) -> generic.DataFrameGroupBy[Any, Any]: ...
391391
@overload
392392
def __getitem__(
393393
self: BaseGroupBy[Series[S1]],
394394
idx: list[str] | Index | Series[S1] | MaskType | tuple[Hashable | slice, ...],
395-
) -> generic.SeriesGroupBy: ...
395+
) -> generic.SeriesGroupBy[Any, Any]: ...
396396
@overload
397397
def __getitem__(self: BaseGroupBy[Series[S1]], idx: Scalar) -> S1: ...

pandas-stubs/core/resample.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ _SeriesGroupByFuncArgs: TypeAlias = (
5555
)
5656

5757
class Resampler(BaseGroupBy[NDFrameT]):
58-
def __getattr__(self, attr: str) -> SeriesGroupBy: ...
58+
def __getattr__(self, attr: str) -> SeriesGroupBy[Any, Any]: ...
5959
@overload
6060
def aggregate(
6161
self: Resampler[DataFrame],

pandas-stubs/core/series.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,10 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
779779
def items(self) -> Iterator[tuple[Hashable, S1]]: ...
780780
def keys(self) -> Index: ...
781781
@overload
782-
def to_dict(self, *, into: type[dict] = ...) -> dict[Any, S1]: ...
782+
def to_dict(self, *, into: type[dict[Any, Any]] = ...) -> dict[Hashable, S1]: ...
783783
@overload
784784
def to_dict(
785-
self, *, into: type[MutableMapping] | MutableMapping[Any, Any]
785+
self, *, into: type[MutableMapping[Any, Any]] | MutableMapping[Any, Any]
786786
) -> MutableMapping[Hashable, S1]: ...
787787
def to_frame(self, name: object | None = ...) -> DataFrame: ...
788788
@overload
@@ -1105,7 +1105,7 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
11051105
def swaplevel(
11061106
self, i: Level = -2, j: Level = -1, copy: _bool = True
11071107
) -> Series[S1]: ...
1108-
def reorder_levels(self, order: list) -> Series[S1]: ...
1108+
def reorder_levels(self, order: list[Any]) -> Series[S1]: ...
11091109
def explode(self, ignore_index: _bool = ...) -> Series[S1]: ...
11101110
def unstack(
11111111
self,

tests/frame/test_frame.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3021,10 +3021,10 @@ def test_to_dict_simple() -> None:
30213021
check(assert_type(data.to_dict("dict"), dict[Hashable, Any]), dict)
30223022
check(assert_type(data.to_dict("list"), dict[Hashable, Any]), dict)
30233023
check(assert_type(data.to_dict("series"), dict[Hashable, Any]), dict)
3024-
check(assert_type(data.to_dict("split"), dict[str, list]), dict, str)
3024+
check(assert_type(data.to_dict("split"), dict[str, list[Any]]), dict, str)
30253025

30263026
# orient param accepting "tight" added in 1.4.0 https://pandas.pydata.org/docs/whatsnew/v1.4.0.html
3027-
check(assert_type(data.to_dict("tight"), dict[str, list]), dict, str)
3027+
check(assert_type(data.to_dict("tight"), dict[str, list[Any]]), dict, str)
30283028

30293029
if TYPE_CHECKING_INVALID_USAGE:
30303030

@@ -3075,7 +3075,7 @@ def test_to_dict_into_defaultdict() -> None:
30753075
defaultdict,
30763076
)
30773077
check(
3078-
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]),
3078+
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list[Any]]),
30793079
defaultdict,
30803080
str,
30813081
)
@@ -3093,7 +3093,11 @@ def test_to_dict_into_ordered_dict() -> None:
30933093

30943094
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
30953095

3096-
check(assert_type(data.to_dict(into=OrderedDict), OrderedDict), OrderedDict, tuple)
3096+
check(
3097+
assert_type(data.to_dict(into=OrderedDict), OrderedDict[Any, Any]),
3098+
OrderedDict,
3099+
tuple,
3100+
)
30973101
check(
30983102
assert_type(
30993103
data.to_dict("index", into=OrderedDict),
@@ -3102,12 +3106,16 @@ def test_to_dict_into_ordered_dict() -> None:
31023106
OrderedDict,
31033107
)
31043108
check(
3105-
assert_type(data.to_dict("tight", into=OrderedDict), MutableMapping[str, list]),
3109+
assert_type(
3110+
data.to_dict("tight", into=OrderedDict), MutableMapping[str, list[Any]]
3111+
),
31063112
OrderedDict,
31073113
str,
31083114
)
31093115
check(
3110-
assert_type(data.to_dict("records", into=OrderedDict), list[OrderedDict]),
3116+
assert_type(
3117+
data.to_dict("records", into=OrderedDict), list[OrderedDict[Any, Any]]
3118+
),
31113119
list,
31123120
OrderedDict,
31133121
)
@@ -3446,16 +3454,24 @@ def test_to_dict_index() -> None:
34463454
dict,
34473455
)
34483456
check(
3449-
assert_type(df.to_dict(orient="split", index=True), dict[str, list]), dict, str
3457+
assert_type(df.to_dict(orient="split", index=True), dict[str, list[Any]]),
3458+
dict,
3459+
str,
34503460
)
34513461
check(
3452-
assert_type(df.to_dict(orient="tight", index=True), dict[str, list]), dict, str
3462+
assert_type(df.to_dict(orient="tight", index=True), dict[str, list[Any]]),
3463+
dict,
3464+
str,
34533465
)
34543466
check(
3455-
assert_type(df.to_dict(orient="tight", index=False), dict[str, list]), dict, str
3467+
assert_type(df.to_dict(orient="tight", index=False), dict[str, list[Any]]),
3468+
dict,
3469+
str,
34563470
)
34573471
check(
3458-
assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str
3472+
assert_type(df.to_dict(orient="split", index=False), dict[str, list[Any]]),
3473+
dict,
3474+
str,
34593475
)
34603476
if TYPE_CHECKING_INVALID_USAGE:
34613477
_0 = df.to_dict(orient="records", index=False) # type: ignore[call-overload] # pyright: ignore[reportArgumentType,reportCallIssue]

tests/series/test_series.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ def makeseries(x: float) -> pd.Series:
838838
def retseries(x: float) -> float:
839839
return x
840840

841-
check(assert_type(s.apply(retseries).tolist(), list), list)
841+
check(assert_type(s.apply(retseries).tolist(), list[Any]), list)
842842

843843
def retlist(x: float) -> list[float]:
844844
return [x]
@@ -1780,7 +1780,7 @@ def test_types_to_list() -> None:
17801780

17811781
def test_types_to_dict() -> None:
17821782
s = pd.Series(["a", "b", "c"], dtype=str)
1783-
assert_type(s.to_dict(), dict[Any, str])
1783+
assert_type(s.to_dict(), dict[Hashable, str])
17841784

17851785

17861786
def test_categorical_codes() -> None:
@@ -2182,7 +2182,7 @@ def test_change_to_dict_return_type() -> None:
21822182
value = ["a", "b", "c"]
21832183
df = pd.DataFrame(zip(id, value), columns=["id", "value"])
21842184
fd = df.set_index("id")["value"].to_dict()
2185-
check(assert_type(fd, dict[Any, Any]), dict)
2185+
check(assert_type(fd, dict[Hashable, Any]), dict)
21862186

21872187

21882188
ASTYPE_BOOL_ARGS: list[tuple[BooleanDtypeArg, type]] = [

tests/test_api_typing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportMissingTypeArgument=false
12
"""Test module for classes in pandas.api.typing."""
23

34
from typing import TypeAlias
@@ -26,9 +27,7 @@
2627
Window,
2728
)
2829
import pytest
29-
from typing_extensions import (
30-
assert_type,
31-
)
30+
from typing_extensions import assert_type
3231

3332
from tests import (
3433
check,

tests/test_extension.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import decimal
2+
from typing import Any
23

34
import numpy as np
45
import pandas as pd
@@ -29,9 +30,9 @@ def test_tolist() -> None:
2930
s1 = pd.Series(data1)
3031
# python/mypy#19952: mypy believes ExtensionArray and its subclasses have a
3132
# conflict and gives Any for s.array
32-
check(assert_type(s.array.tolist(), list), list) # type: ignore[assert-type]
33-
check(assert_type(s1.array.tolist(), list), list)
34-
check(assert_type(pd.array([1, 2, 3]).tolist(), list), list)
33+
check(assert_type(s.array.tolist(), list[Any]), list) # type: ignore[assert-type]
34+
check(assert_type(s1.array.tolist(), list[Any]), list)
35+
check(assert_type(pd.array([1, 2, 3]).tolist(), list[Any]), list)
3536

3637

3738
def test_ExtensionArray_reduce_accumulate() -> None:

tests/test_resampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportMissingTypeArgument=false
12
from collections.abc import (
23
Hashable,
34
Iterator,

0 commit comments

Comments
 (0)