Skip to content

Commit 03d703b

Browse files
authored
Add weighted estimation in function interface (#3586)
* Add weighted estimation in function interface * Add weights to relplot and catplot * Fix relplot weights * Matplotlib backcompat in tests
1 parent 2bb945c commit 03d703b

File tree

7 files changed

+127
-51
lines changed

7 files changed

+127
-51
lines changed

seaborn/_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def __call__(self, data, var):
518518
return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max})
519519

520520

521-
class WeightedEstimateAggregator:
521+
class WeightedAggregator:
522522

523523
def __init__(self, estimator, errorbar=None, **boot_kws):
524524
"""

seaborn/_stats/aggregation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from seaborn._stats.base import Stat
1111
from seaborn._statistics import (
1212
EstimateAggregator,
13-
WeightedEstimateAggregator,
13+
WeightedAggregator,
1414
)
1515
from seaborn._core.typing import Vector
1616

@@ -105,7 +105,7 @@ def __call__(
105105

106106
boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
107107
if "weight" in data:
108-
engine = WeightedEstimateAggregator(self.func, self.errorbar, **boot_kws)
108+
engine = WeightedAggregator(self.func, self.errorbar, **boot_kws)
109109
else:
110110
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)
111111

seaborn/categorical.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
_version_predates,
2929
)
3030
from seaborn._compat import MarkerStyle
31-
from seaborn._statistics import EstimateAggregator, LetterValues
31+
from seaborn._statistics import (
32+
EstimateAggregator,
33+
LetterValues,
34+
WeightedAggregator,
35+
)
3236
from seaborn.palettes import light_palette
3337
from seaborn.axisgrid import FacetGrid, _facet_docs
3438

@@ -1385,11 +1389,16 @@ class _CategoricalAggPlotter(_CategoricalPlotter):
13851389
.. versionadded:: v0.12.0
13861390
n_boot : int
13871391
Number of bootstrap samples used to compute confidence intervals.
1392+
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
1393+
Seed or random number generator for reproducible bootstrapping.
13881394
units : name of variable in `data` or vector data
13891395
Identifier of sampling units; used by the errorbar function to
13901396
perform a multilevel bootstrap and account for repeated measures
1391-
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
1392-
Seed or random number generator for reproducible bootstrapping.\
1397+
weights : name of variable in `data` or vector data
1398+
Data values or column used to compute weighted statistics.
1399+
Note that the use of weights may limit other statistical options.
1400+
1401+
.. versionadded:: v0.13.1\
13931402
"""),
13941403
ci=dedent("""\
13951404
ci : float
@@ -2308,10 +2317,10 @@ def swarmplot(
23082317

23092318
def barplot(
23102319
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
2311-
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
2312-
orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
2313-
width=.8, dodge="auto", gap=0, log_scale=None, native_scale=False, formatter=None,
2314-
legend="auto", capsize=0, err_kws=None,
2320+
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
2321+
weights=None, orient=None, color=None, palette=None, saturation=.75,
2322+
fill=True, hue_norm=None, width=.8, dodge="auto", gap=0, log_scale=None,
2323+
native_scale=False, formatter=None, legend="auto", capsize=0, err_kws=None,
23152324
ci=deprecated, errcolor=deprecated, errwidth=deprecated, ax=None, **kwargs,
23162325
):
23172326

@@ -2324,7 +2333,7 @@ def barplot(
23242333

23252334
p = _CategoricalAggPlotter(
23262335
data=data,
2327-
variables=dict(x=x, y=y, hue=hue, units=units),
2336+
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
23282337
order=order,
23292338
orient=orient,
23302339
color=color,
@@ -2354,7 +2363,8 @@ def barplot(
23542363
p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)
23552364
color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)
23562365

2357-
aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
2366+
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
2367+
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
23582368
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)
23592369

23602370
# Deprecations to remove in v0.15.0.
@@ -2449,20 +2459,19 @@ def barplot(
24492459

24502460
def pointplot(
24512461
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
2452-
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
2453-
color=None, palette=None, hue_norm=None, markers=default, linestyles=default,
2454-
dodge=False, log_scale=None, native_scale=False, orient=None, capsize=0,
2455-
formatter=None, legend="auto", err_kws=None,
2462+
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
2463+
weights=None, color=None, palette=None, hue_norm=None, markers=default,
2464+
linestyles=default, dodge=False, log_scale=None, native_scale=False,
2465+
orient=None, capsize=0, formatter=None, legend="auto", err_kws=None,
24562466
ci=deprecated, errwidth=deprecated, join=deprecated, scale=deprecated,
2457-
ax=None,
2458-
**kwargs,
2467+
ax=None, **kwargs,
24592468
):
24602469

24612470
errorbar = utils._deprecate_ci(errorbar, ci)
24622471

24632472
p = _CategoricalAggPlotter(
24642473
data=data,
2465-
variables=dict(x=x, y=y, hue=hue, units=units),
2474+
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
24662475
order=order,
24672476
orient=orient,
24682477
# Handle special backwards compatibility where pointplot originally
@@ -2489,7 +2498,8 @@ def pointplot(
24892498
p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
24902499
color = _default_color(ax.plot, hue, color, kwargs)
24912500

2492-
aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
2501+
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
2502+
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
24932503
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)
24942504

24952505
# Deprecations to remove in v0.15.0.
@@ -2729,12 +2739,12 @@ def countplot(
27292739

27302740
def catplot(
27312741
data=None, *, x=None, y=None, hue=None, row=None, col=None, kind="strip",
2732-
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
2733-
order=None, hue_order=None, row_order=None, col_order=None, col_wrap=None,
2734-
height=5, aspect=1, log_scale=None, native_scale=False, formatter=None,
2735-
orient=None, color=None, palette=None, hue_norm=None, legend="auto",
2736-
legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None,
2737-
ci=deprecated, **kwargs
2742+
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
2743+
weights=None, order=None, hue_order=None, row_order=None, col_order=None,
2744+
col_wrap=None, height=5, aspect=1, log_scale=None, native_scale=False,
2745+
formatter=None, orient=None, color=None, palette=None, hue_norm=None,
2746+
legend="auto", legend_out=True, sharex=True, sharey=True,
2747+
margin_titles=False, facet_kws=None, ci=deprecated, **kwargs
27382748
):
27392749

27402750
# Check for attempt to plot onto specific axes and warn
@@ -2764,7 +2774,9 @@ def catplot(
27642774

27652775
p = Plotter(
27662776
data=data,
2767-
variables=dict(x=x, y=y, hue=hue, row=row, col=col, units=units),
2777+
variables=dict(
2778+
x=x, y=y, hue=hue, row=row, col=col, units=units, weight=weights
2779+
),
27682780
order=order,
27692781
orient=orient,
27702782
# Handle special backwards compatibility where pointplot originally
@@ -2840,6 +2852,14 @@ def catplot(
28402852
if dodge == "auto":
28412853
dodge = p._dodge_needed()
28422854

2855+
if "weight" in p.plot_data:
2856+
if kind not in ["bar", "point"]:
2857+
msg = f"The `weights` parameter has no effect with kind={kind!r}."
2858+
warnings.warn(msg, stacklevel=2)
2859+
agg_cls = WeightedAggregator
2860+
else:
2861+
agg_cls = EstimateAggregator
2862+
28432863
if kind == "strip":
28442864

28452865
jitter = kwargs.pop("jitter", True)
@@ -2989,9 +3009,7 @@ def catplot(
29893009

29903010
elif kind == "point":
29913011

2992-
aggregator = EstimateAggregator(
2993-
estimator, errorbar, n_boot=n_boot, seed=seed
2994-
)
3012+
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
29953013

29963014
markers = kwargs.pop("markers", default)
29973015
linestyles = kwargs.pop("linestyles", default)
@@ -3025,9 +3043,8 @@ def catplot(
30253043

30263044
elif kind == "bar":
30273045

3028-
aggregator = EstimateAggregator(
3029-
estimator, errorbar, n_boot=n_boot, seed=seed
3030-
)
3046+
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
3047+
30313048
err_kws, capsize = p._err_kws_backcompat(
30323049
_normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
30333050
errcolor=kwargs.pop("errcolor", deprecated),

seaborn/relational.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
_normalize_kwargs,
1818
_scatter_legend_artist,
1919
)
20-
from ._statistics import EstimateAggregator
20+
from ._statistics import EstimateAggregator, WeightedAggregator
2121
from .axisgrid import FacetGrid, _facet_docs
2222
from ._docstrings import DocstringComponents, _core_docs
2323

@@ -252,7 +252,8 @@ def plot(self, ax, kws):
252252
raise ValueError(err.format(self.err_style))
253253

254254
# Initialize the aggregation object
255-
agg = EstimateAggregator(
255+
weighted = "weight" in self.plot_data
256+
agg = (WeightedAggregator if weighted else EstimateAggregator)(
256257
self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
257258
)
258259

@@ -464,7 +465,7 @@ def plot(self, ax, kws):
464465

465466
def lineplot(
466467
data=None, *,
467-
x=None, y=None, hue=None, size=None, style=None, units=None,
468+
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
468469
palette=None, hue_order=None, hue_norm=None,
469470
sizes=None, size_order=None, size_norm=None,
470471
dashes=True, markers=None, style_order=None,
@@ -478,7 +479,9 @@ def lineplot(
478479

479480
p = _LinePlotter(
480481
data=data,
481-
variables=dict(x=x, y=y, hue=hue, size=size, style=style, units=units),
482+
variables=dict(
483+
x=x, y=y, hue=hue, size=size, style=style, units=units, weight=weights
484+
),
482485
estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
483486
sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
484487
legend=legend,
@@ -536,6 +539,10 @@ def lineplot(
536539
and/or markers. Can have a numeric dtype but will always be treated
537540
as categorical.
538541
{params.rel.units}
542+
weights : vector or key in `data`
543+
Data values or column used to compute weighted estimation.
544+
Note that use of weights currently limits the choice of statistics
545+
to a 'mean' estimator and 'ci' errorbar.
539546
{params.core.palette}
540547
{params.core.hue_order}
541548
{params.core.hue_norm}
@@ -687,7 +694,7 @@ def scatterplot(
687694

688695
def relplot(
689696
data=None, *,
690-
x=None, y=None, hue=None, size=None, style=None, units=None,
697+
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
691698
row=None, col=None, col_wrap=None, row_order=None, col_order=None,
692699
palette=None, hue_order=None, hue_norm=None,
693700
sizes=None, size_order=None, size_norm=None,
@@ -725,9 +732,14 @@ def relplot(
725732
variables = dict(x=x, y=y, hue=hue, size=size, style=style)
726733
if kind == "line":
727734
variables["units"] = units
728-
elif units is not None:
729-
msg = "The `units` parameter of `relplot` has no effect with kind='scatter'"
730-
warnings.warn(msg, stacklevel=2)
735+
variables["weight"] = weights
736+
else:
737+
if units is not None:
738+
msg = "The `units` parameter has no effect with kind='scatter'."
739+
warnings.warn(msg, stacklevel=2)
740+
if weights is not None:
741+
msg = "The `weights` parameter has no effect with kind='scatter'."
742+
warnings.warn(msg, stacklevel=2)
731743
p = Plotter(
732744
data=data,
733745
variables=variables,
@@ -780,17 +792,18 @@ def relplot(
780792

781793
# Add the grid semantics onto the plotter
782794
grid_variables = dict(
783-
x=x, y=y, row=row, col=col,
784-
hue=hue, size=size, style=style,
795+
x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
785796
)
786797
if kind == "line":
787-
grid_variables["units"] = units
798+
grid_variables.update(units=units, weights=weights)
788799
p.assign_variables(data, grid_variables)
789800

790801
# Define the named variables for plotting on each facet
791802
# Rename the variables with a leading underscore to avoid
792803
# collisions with faceting variable names
793804
plot_variables = {v: f"_{v}" for v in variables}
805+
if "weight" in plot_variables:
806+
plot_variables["weights"] = plot_variables.pop("weight")
794807
plot_kws.update(plot_variables)
795808

796809
# Pass the row/col variables to FacetGrid with their original
@@ -918,6 +931,10 @@ def relplot(
918931
Grouping variable that will produce elements with different styles.
919932
Can have a numeric dtype but will always be treated as categorical.
920933
{params.rel.units}
934+
weights : vector or key in `data`
935+
Data values or column used to compute weighted estimation.
936+
Note that use of weights currently limits the choice of statistics
937+
to a 'mean' estimator and 'ci' errorbar.
921938
{params.facets.rowcol}
922939
{params.facets.col_wrap}
923940
row_order, col_order : lists of strings

tests/test_categorical.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,13 @@ def test_estimate_func(self, long_df):
21312131
for i, bar in enumerate(ax.patches):
21322132
assert bar.get_height() == approx(agg_df[order[i]])
21332133

2134+
def test_weighted_estimate(self, long_df):
2135+
2136+
ax = barplot(long_df, y="y", weights="x")
2137+
height = ax.patches[0].get_height()
2138+
expected = np.average(long_df["y"], weights=long_df["x"])
2139+
assert height == expected
2140+
21342141
def test_estimate_log_transform(self, long_df):
21352142

21362143
ax = mpl.figure.Figure().subplots()
@@ -2490,6 +2497,13 @@ def test_estimate(self, long_df, estimator):
24902497
for i, xy in enumerate(ax.lines[0].get_xydata()):
24912498
assert tuple(xy) == approx((i, agg_df[order[i]]))
24922499

2500+
def test_weighted_estimate(self, long_df):
2501+
2502+
ax = pointplot(long_df, y="y", weights="x")
2503+
val = ax.lines[0].get_ydata().item()
2504+
expected = np.average(long_df["y"], weights=long_df["x"])
2505+
assert val == expected
2506+
24932507
def test_estimate_log_transform(self, long_df):
24942508

24952509
ax = mpl.figure.Figure().subplots()
@@ -3133,6 +3147,12 @@ def test_legend_with_auto(self):
31333147
g2 = catplot(self.df, x="g", y="y", hue="g", legend=True)
31343148
assert g2._legend is not None
31353149

3150+
def test_weights_warning(self, long_df):
3151+
3152+
with pytest.warns(UserWarning, match="The `weights` parameter"):
3153+
g = catplot(long_df, x="a", y="y", weights="z")
3154+
assert g.ax is not None
3155+
31363156

31373157
class TestBeeswarm:
31383158

tests/test_relational.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df):
578578
expected_paths = [paths[val] for val in grp_df["a"]]
579579
assert self.paths_equal(points.get_paths(), expected_paths)
580580

581+
def test_relplot_weighted_estimator(self, long_df):
582+
583+
g = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
584+
ydata = g.ax.lines[0].get_ydata()
585+
for i, level in enumerate(categorical_order(long_df["a"])):
586+
pos_df = long_df[long_df["a"] == level]
587+
expected = np.average(pos_df["y"], weights=pos_df["x"])
588+
assert ydata[i] == pytest.approx(expected)
589+
581590
def test_relplot_stringy_numerics(self, long_df):
582591

583592
long_df["x_str"] = long_df["x"].astype(str)
@@ -668,12 +677,16 @@ def test_facet_variable_collision(self, long_df):
668677
)
669678
assert g.axes.shape == (1, len(col_data.unique()))
670679

671-
def test_relplot_scatter_units(self, long_df):
680+
def test_relplot_scatter_unused_variables(self, long_df):
672681

673682
with pytest.warns(UserWarning, match="The `units` parameter"):
674683
g = relplot(long_df, x="x", y="y", units="a")
675684
assert g.ax is not None
676685

686+
with pytest.warns(UserWarning, match="The `weights` parameter"):
687+
g = relplot(long_df, x="x", y="y", weights="x")
688+
assert g.ax is not None
689+
677690
def test_ax_kwarg_removal(self, long_df):
678691

679692
f, ax = plt.subplots()
@@ -1055,6 +1068,15 @@ def test_plot(self, long_df, repeated_df):
10551068
ax.clear()
10561069
p.plot(ax, {})
10571070

1071+
def test_weights(self, long_df):
1072+
1073+
ax = lineplot(long_df, x="a", y="y", weights="x")
1074+
vals = ax.lines[0].get_ydata()
1075+
for i, level in enumerate(categorical_order(long_df["a"])):
1076+
pos_df = long_df[long_df["a"] == level]
1077+
expected = np.average(pos_df["y"], weights=pos_df["x"])
1078+
assert vals[i] == pytest.approx(expected)
1079+
10581080
def test_non_aggregated_data(self):
10591081

10601082
x = [1, 2, 3, 4]

0 commit comments

Comments
 (0)