Skip to content

Commit fb61ded

Browse files
authored
Add Est stat and Interval mark to show error bars (#2912)
* Add a docstring for seaborn.objects namespace * Add Est stat (mostly copied from EstimateAggregator) * Handle cases where x or y are not defined better * Improve datalim update with collections * Handle matplotlib edge cases with line capstyles * Add Interval mark * Add Interval unit tests * Revert Est to use EstimateAggregator and add (light) tests * Pandas (?) backcompat
1 parent b5a85ff commit fb61ded

File tree

9 files changed

+294
-47
lines changed

9 files changed

+294
-47
lines changed

seaborn/_core/plot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,10 +1338,11 @@ def split_generator(keep_na=False) -> Generator:
13381338
# Matplotlib (usually?) masks nan data, so this should "work".
13391339
# Downstream code can also drop these rows, at some speed cost.
13401340
present = axes_df.notna().all(axis=1)
1341-
axes_df = axes_df.assign(
1342-
x=axes_df["x"].where(present),
1343-
y=axes_df["y"].where(present),
1344-
)
1341+
nulled = {}
1342+
for axis in "xy":
1343+
if axis in axes_df:
1344+
nulled[axis] = axes_df[axis].where(present)
1345+
axes_df = axes_df.assign(**nulled)
13451346
else:
13461347
axes_df = axes_df.dropna()
13471348

seaborn/_core/scales.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,10 @@ def normalize(x):
350350
]
351351

352352
def spacer(x):
353-
return np.min(np.diff(np.sort(x.dropna().unique())))
353+
x = x.dropna().unique()
354+
if len(x) < 2:
355+
return np.nan
356+
return np.min(np.diff(np.sort(x)))
354357
new._spacer = spacer
355358

356359
# TODO How to allow disabling of legend for all uses of property?

seaborn/_marks/bars.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,8 @@ def _plot(self, split_gen, scales, orient):
200200
# Workaround for matplotlib autoscaling bug
201201
# https://github.com/matplotlib/matplotlib/issues/11898
202202
# https://github.com/matplotlib/matplotlib/issues/23129
203-
xy = np.vstack([path.vertices for path in col.get_paths()])
204-
ax.dataLim.update_from_data_xy(
205-
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
206-
)
203+
xys = np.vstack([path.vertices for path in col.get_paths()])
204+
ax.update_datalim(xys)
207205

208206
if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):
209207

seaborn/_marks/lines.py

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def _plot(self, split_gen, scales, orient):
5050
if self._sort:
5151
data = data.sort_values(orient)
5252

53+
artist_kws = self.artist_kws.copy()
54+
self._handle_capstyle(artist_kws, vals)
55+
5356
line = mpl.lines.Line2D(
5457
data["x"].to_numpy(),
5558
data["y"].to_numpy(),
@@ -61,7 +64,7 @@ def _plot(self, split_gen, scales, orient):
6164
markerfacecolor=vals["fillcolor"],
6265
markeredgecolor=vals["edgecolor"],
6366
markeredgewidth=vals["edgewidth"],
64-
**self.artist_kws,
67+
**artist_kws,
6568
)
6669
ax.add_line(line)
6770

@@ -77,6 +80,9 @@ def _legend_artist(self, variables, value, scales):
7780
if Version(mpl.__version__) < Version("3.3.0"):
7881
vals["marker"] = vals["marker"]._marker
7982

83+
artist_kws = self.artist_kws.copy()
84+
self._handle_capstyle(artist_kws, vals)
85+
8086
return mpl.lines.Line2D(
8187
[], [],
8288
color=vals["color"],
@@ -87,9 +93,17 @@ def _legend_artist(self, variables, value, scales):
8793
markerfacecolor=vals["fillcolor"],
8894
markeredgecolor=vals["edgecolor"],
8995
markeredgewidth=vals["edgewidth"],
90-
**self.artist_kws,
96+
**artist_kws,
9197
)
9298

99+
def _handle_capstyle(self, kws, vals):
100+
101+
# Workaround for this matplotlib issue:
102+
# https://github.com/matplotlib/matplotlib/issues/23437
103+
if vals["linestyle"][1] is None:
104+
capstyle = kws.get("solid_capstyle", mpl.rcParams["lines.solid_capstyle"])
105+
kws["dash_capstyle"] = capstyle
106+
93107

94108
@dataclass
95109
class Line(Path):
@@ -111,7 +125,15 @@ class Paths(Mark):
111125

112126
_sort: ClassVar[bool] = False
113127

114-
def _plot(self, split_gen, scales, orient):
128+
def __post_init__(self):
129+
130+
# LineCollection artists have a capstyle property but don't source its value
131+
# from the rc, so we do that manually here. Unfortunately, because we add
132+
# only one LineCollection, we have to use the same capstyle for all lines
133+
# even when they are dashed. It's a slight inconsistency, but looks fine IMO.
134+
self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"])
135+
136+
def _setup_lines(self, split_gen, scales, orient):
115137

116138
line_data = {}
117139

@@ -131,36 +153,42 @@ def _plot(self, split_gen, scales, orient):
131153
if self._sort:
132154
data = data.sort_values(orient)
133155

134-
# TODO comment about block consolidation
156+
# Column stack to avoid block consolidation
135157
xy = np.column_stack([data["x"], data["y"]])
136158
line_data[ax]["segments"].append(xy)
137159
line_data[ax]["colors"].append(vals["color"])
138160
line_data[ax]["linewidths"].append(vals["linewidth"])
139161
line_data[ax]["linestyles"].append(vals["linestyle"])
140162

163+
return line_data
164+
165+
def _plot(self, split_gen, scales, orient):
166+
167+
line_data = self._setup_lines(split_gen, scales, orient)
168+
141169
for ax, ax_data in line_data.items():
142-
lines = mpl.collections.LineCollection(
143-
**ax_data,
144-
**self.artist_kws,
145-
)
146-
ax.add_collection(lines, autolim=False)
170+
lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
171+
# Handle datalim update manually
147172
# https://github.com/matplotlib/matplotlib/issues/23129
148-
# TODO get paths from lines object?
173+
ax.add_collection(lines, autolim=False)
149174
xy = np.concatenate(ax_data["segments"])
150-
ax.dataLim.update_from_data_xy(
151-
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
152-
)
175+
ax.update_datalim(xy)
153176

154177
def _legend_artist(self, variables, value, scales):
155178

156179
key = resolve_properties(self, {v: value for v in variables}, scales)
157180

181+
artist_kws = self.artist_kws.copy()
182+
capstyle = artist_kws.pop("capstyle")
183+
artist_kws["solid_capstyle"] = capstyle
184+
artist_kws["dash_capstyle"] = capstyle
185+
158186
return mpl.lines.Line2D(
159187
[], [],
160188
color=key["color"],
161189
linewidth=key["linewidth"],
162190
linestyle=key["linestyle"],
163-
**self.artist_kws,
191+
**artist_kws,
164192
)
165193

166194

@@ -170,3 +198,41 @@ class Lines(Paths):
170198
A faster but less-flexible mark for drawing many lines.
171199
"""
172200
_sort: ClassVar[bool] = True
201+
202+
203+
@dataclass
204+
class Interval(Paths):
205+
"""
206+
An oriented line mark drawn between min/max values.
207+
"""
208+
def _setup_lines(self, split_gen, scales, orient):
209+
210+
line_data = {}
211+
212+
other = {"x": "y", "y": "x"}[orient]
213+
214+
for keys, data, ax in split_gen(keep_na=not self._sort):
215+
216+
if ax not in line_data:
217+
line_data[ax] = {
218+
"segments": [],
219+
"colors": [],
220+
"linewidths": [],
221+
"linestyles": [],
222+
}
223+
224+
vals = resolve_properties(self, keys, scales)
225+
vals["color"] = resolve_color(self, keys, scales=scales)
226+
227+
cols = [orient, f"{other}min", f"{other}max"]
228+
data = data[cols].melt(orient, value_name=other)[["x", "y"]]
229+
segments = [d.to_numpy() for _, d in data.groupby(orient)]
230+
231+
line_data[ax]["segments"].extend(segments)
232+
233+
n = len(segments)
234+
line_data[ax]["colors"].extend([vals["color"]] * n)
235+
line_data[ax]["linewidths"].extend([vals["linewidth"]] * n)
236+
line_data[ax]["linestyles"].extend([vals["linestyle"]] * n)
237+
238+
return line_data

seaborn/_stats/aggregation.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations
22
from dataclasses import dataclass
3-
from typing import ClassVar
3+
from typing import ClassVar, Callable
44

5+
import pandas as pd
6+
from pandas import DataFrame
7+
8+
from seaborn._core.scales import Scale
9+
from seaborn._core.groupby import GroupBy
510
from seaborn._stats.base import Stat
11+
from seaborn._statistics import EstimateAggregator
612

7-
from typing import TYPE_CHECKING
8-
if TYPE_CHECKING:
9-
from typing import Callable
10-
from numbers import Number
11-
from seaborn._core.typing import Vector
13+
from seaborn._core.typing import Vector
1214

1315

1416
@dataclass
@@ -18,23 +20,22 @@ class Agg(Stat):
1820
1921
Parameters
2022
----------
21-
func
22-
Name of a method understood by Pandas or an arbitrary vector -> scalar function.
23+
func : str or callable
24+
Name of a :class:`pandas.Series` method or a vector -> scalar function.
2325
2426
"""
25-
# TODO In current practice we will always have a numeric x/y variable,
26-
# but they may represent non-numeric values. Needs clear documentation.
27-
func: str | Callable[[Vector], Number] = "mean"
27+
func: str | Callable[[Vector], float] = "mean"
2828

2929
group_by_orient: ClassVar[bool] = True
3030

31-
def __call__(self, data, groupby, orient, scales):
31+
def __call__(
32+
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
33+
) -> DataFrame:
3234

3335
var = {"x": "y", "y": "x"}.get(orient)
3436
res = (
3537
groupby
3638
.agg(data, {var: self.func})
37-
# TODO Could be an option not to drop NA?
3839
.dropna()
3940
.reset_index(drop=True)
4041
)
@@ -43,19 +44,56 @@ def __call__(self, data, groupby, orient, scales):
4344

4445
@dataclass
4546
class Est(Stat):
47+
"""
48+
Calculate a point estimate and error bar interval.
4649
47-
# TODO a string here must be a numpy ufunc?
48-
func: str | Callable[[Vector], Number] = "mean"
50+
Parameters
51+
----------
52+
func : str or callable
53+
Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
54+
errorbar : str, (str, float) tuple, or callable
55+
Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
56+
with a method name and a level parameter, or a function that maps from a
57+
vector to a (min, max) interval.
58+
n_boot : int
59+
Number of bootstrap samples to draw for "ci" errorbars.
60+
seed : int
61+
Seed for the PRNG used to draw bootstrap samples.
4962
50-
# TODO type errorbar options with literal?
63+
"""
64+
func: str | Callable[[Vector], float] = "mean"
5165
errorbar: str | tuple[str, float] = ("ci", 95)
66+
n_boot: int = 1000
67+
seed: int | None = None
5268

5369
group_by_orient: ClassVar[bool] = True
5470

55-
def __call__(self, data, groupby, orient, scales):
71+
def _process(
72+
self, data: DataFrame, var: str, estimator: EstimateAggregator
73+
) -> DataFrame:
74+
# Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
75+
# which we could probably make more general to allow Series return
76+
res = estimator(data, var)
77+
return pd.DataFrame([res])
5678

57-
# TODO port code over from _statistics
58-
...
79+
def __call__(
80+
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
81+
) -> DataFrame:
82+
83+
boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
84+
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)
85+
86+
var = {"x": "y", "y": "x"}.get(orient)
87+
res = (
88+
groupby
89+
.apply(data, self._process, var, engine)
90+
.dropna(subset=["x", "y"])
91+
.reset_index(drop=True)
92+
)
93+
94+
res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})
95+
96+
return res
5997

6098

6199
@dataclass

seaborn/_stats/histograms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ class Hist(Stat):
3131
# Q: would Discrete() scale imply binwidth=1 or bins centered on integers?
3232
discrete: bool = False
3333

34+
# TODO Note that these methods are mostly copied from _statistics.Histogram,
35+
# but it only computes univariate histograms. We should reconcile the code.
36+
3437
def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete):
3538
"""Inner function that takes bin parameters as arguments."""
3639
vals = vals.dropna()

seaborn/objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
from seaborn._marks.base import Mark # noqa: F401
3232
from seaborn._marks.area import Area, Ribbon # noqa: F401
3333
from seaborn._marks.bars import Bar, Bars # noqa: F401
34-
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
34+
from seaborn._marks.lines import Line, Lines, Path, Paths, Interval # noqa: F401
3535
from seaborn._marks.scatter import Dot, Scatter # noqa: F401
3636

3737
from seaborn._stats.base import Stat # noqa: F401
38-
from seaborn._stats.aggregation import Agg # noqa: F401
38+
from seaborn._stats.aggregation import Agg, Est # noqa: F401
3939
from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401
4040
from seaborn._stats.histograms import Hist # noqa: F401
4141

0 commit comments

Comments
 (0)