Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion seaborn/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def __call__(self, data, var):
return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max})


class WeightedEstimateAggregator:
class WeightedAggregator:

def __init__(self, estimator, errorbar=None, **boot_kws):
"""
Expand Down
4 changes: 2 additions & 2 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from seaborn._stats.base import Stat
from seaborn._statistics import (
EstimateAggregator,
WeightedEstimateAggregator,
WeightedAggregator,
)
from seaborn._core.typing import Vector

Expand Down Expand Up @@ -105,7 +105,7 @@ def __call__(

boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
if "weight" in data:
engine = WeightedEstimateAggregator(self.func, self.errorbar, **boot_kws)
engine = WeightedAggregator(self.func, self.errorbar, **boot_kws)
else:
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

Expand Down
77 changes: 47 additions & 30 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
_version_predates,
)
from seaborn._compat import MarkerStyle
from seaborn._statistics import EstimateAggregator, LetterValues
from seaborn._statistics import (
EstimateAggregator,
LetterValues,
WeightedAggregator,
)
from seaborn.palettes import light_palette
from seaborn.axisgrid import FacetGrid, _facet_docs

Expand Down Expand Up @@ -1385,11 +1389,16 @@ class _CategoricalAggPlotter(_CategoricalPlotter):
.. versionadded:: v0.12.0
n_boot : int
Number of bootstrap samples used to compute confidence intervals.
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
Seed or random number generator for reproducible bootstrapping.
units : name of variable in `data` or vector data
Identifier of sampling units; used by the errorbar function to
perform a multilevel bootstrap and account for repeated measures
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
Seed or random number generator for reproducible bootstrapping.\
weights : name of variable in `data` or vector data
Data values or column used to compute weighted statistics.
Note that the use of weights may limit other statistical options.

.. versionadded:: v0.13.1\
"""),
ci=dedent("""\
ci : float
Expand Down Expand Up @@ -2308,10 +2317,10 @@ def swarmplot(

def barplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
width=.8, dodge="auto", gap=0, log_scale=None, native_scale=False, formatter=None,
legend="auto", capsize=0, err_kws=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, orient=None, color=None, palette=None, saturation=.75,
fill=True, hue_norm=None, width=.8, dodge="auto", gap=0, log_scale=None,
native_scale=False, formatter=None, legend="auto", capsize=0, err_kws=None,
ci=deprecated, errcolor=deprecated, errwidth=deprecated, ax=None, **kwargs,
):

Expand All @@ -2324,7 +2333,7 @@ def barplot(

p = _CategoricalAggPlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, units=units),
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
order=order,
orient=orient,
color=color,
Expand Down Expand Up @@ -2354,7 +2363,8 @@ def barplot(
p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)
color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)

aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)

# Deprecations to remove in v0.15.0.
Expand Down Expand Up @@ -2449,20 +2459,19 @@ def barplot(

def pointplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
color=None, palette=None, hue_norm=None, markers=default, linestyles=default,
dodge=False, log_scale=None, native_scale=False, orient=None, capsize=0,
formatter=None, legend="auto", err_kws=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, color=None, palette=None, hue_norm=None, markers=default,
linestyles=default, dodge=False, log_scale=None, native_scale=False,
orient=None, capsize=0, formatter=None, legend="auto", err_kws=None,
ci=deprecated, errwidth=deprecated, join=deprecated, scale=deprecated,
ax=None,
**kwargs,
ax=None, **kwargs,
):

errorbar = utils._deprecate_ci(errorbar, ci)

p = _CategoricalAggPlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, units=units),
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
order=order,
orient=orient,
# Handle special backwards compatibility where pointplot originally
Expand All @@ -2489,7 +2498,8 @@ def pointplot(
p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
color = _default_color(ax.plot, hue, color, kwargs)

aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)

# Deprecations to remove in v0.15.0.
Expand Down Expand Up @@ -2729,12 +2739,12 @@ def countplot(

def catplot(
data=None, *, x=None, y=None, hue=None, row=None, col=None, kind="strip",
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
order=None, hue_order=None, row_order=None, col_order=None, col_wrap=None,
height=5, aspect=1, log_scale=None, native_scale=False, formatter=None,
orient=None, color=None, palette=None, hue_norm=None, legend="auto",
legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None,
ci=deprecated, **kwargs
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, order=None, hue_order=None, row_order=None, col_order=None,
col_wrap=None, height=5, aspect=1, log_scale=None, native_scale=False,
formatter=None, orient=None, color=None, palette=None, hue_norm=None,
legend="auto", legend_out=True, sharex=True, sharey=True,
margin_titles=False, facet_kws=None, ci=deprecated, **kwargs
):

# Check for attempt to plot onto specific axes and warn
Expand Down Expand Up @@ -2764,7 +2774,9 @@ def catplot(

p = Plotter(
data=data,
variables=dict(x=x, y=y, hue=hue, row=row, col=col, units=units),
variables=dict(
x=x, y=y, hue=hue, row=row, col=col, units=units, weight=weights
),
order=order,
orient=orient,
# Handle special backwards compatibility where pointplot originally
Expand Down Expand Up @@ -2840,6 +2852,14 @@ def catplot(
if dodge == "auto":
dodge = p._dodge_needed()

if "weight" in p.plot_data:
if kind not in ["bar", "point"]:
msg = f"The `weights` parameter has no effect with kind={kind!r}."
warnings.warn(msg, stacklevel=2)
agg_cls = WeightedAggregator
else:
agg_cls = EstimateAggregator

if kind == "strip":

jitter = kwargs.pop("jitter", True)
Expand Down Expand Up @@ -2989,9 +3009,7 @@ def catplot(

elif kind == "point":

aggregator = EstimateAggregator(
estimator, errorbar, n_boot=n_boot, seed=seed
)
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

markers = kwargs.pop("markers", default)
linestyles = kwargs.pop("linestyles", default)
Expand Down Expand Up @@ -3025,9 +3043,8 @@ def catplot(

elif kind == "bar":

aggregator = EstimateAggregator(
estimator, errorbar, n_boot=n_boot, seed=seed
)
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

err_kws, capsize = p._err_kws_backcompat(
_normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
errcolor=kwargs.pop("errcolor", deprecated),
Expand Down
39 changes: 28 additions & 11 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_normalize_kwargs,
_scatter_legend_artist,
)
from ._statistics import EstimateAggregator
from ._statistics import EstimateAggregator, WeightedAggregator
from .axisgrid import FacetGrid, _facet_docs
from ._docstrings import DocstringComponents, _core_docs

Expand Down Expand Up @@ -252,7 +252,8 @@ def plot(self, ax, kws):
raise ValueError(err.format(self.err_style))

# Initialize the aggregation object
agg = EstimateAggregator(
weighted = "weight" in self.plot_data
agg = (WeightedAggregator if weighted else EstimateAggregator)(
self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
)

Expand Down Expand Up @@ -464,7 +465,7 @@ def plot(self, ax, kws):

def lineplot(
data=None, *,
x=None, y=None, hue=None, size=None, style=None, units=None,
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
palette=None, hue_order=None, hue_norm=None,
sizes=None, size_order=None, size_norm=None,
dashes=True, markers=None, style_order=None,
Expand All @@ -478,7 +479,9 @@ def lineplot(

p = _LinePlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, size=size, style=style, units=units),
variables=dict(
x=x, y=y, hue=hue, size=size, style=style, units=units, weight=weights
),
estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
legend=legend,
Expand Down Expand Up @@ -536,6 +539,10 @@ def lineplot(
and/or markers. Can have a numeric dtype but will always be treated
as categorical.
{params.rel.units}
weights : vector or key in `data`
Data values or column used to compute weighted estimation.
Note that use of weights currently limits the choice of statistics
to a 'mean' estimator and 'ci' errorbar.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
Expand Down Expand Up @@ -687,7 +694,7 @@ def scatterplot(

def relplot(
data=None, *,
x=None, y=None, hue=None, size=None, style=None, units=None,
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
row=None, col=None, col_wrap=None, row_order=None, col_order=None,
palette=None, hue_order=None, hue_norm=None,
sizes=None, size_order=None, size_norm=None,
Expand Down Expand Up @@ -725,9 +732,14 @@ def relplot(
variables = dict(x=x, y=y, hue=hue, size=size, style=style)
if kind == "line":
variables["units"] = units
elif units is not None:
msg = "The `units` parameter of `relplot` has no effect with kind='scatter'"
warnings.warn(msg, stacklevel=2)
variables["weight"] = weights
else:
if units is not None:
msg = "The `units` parameter has no effect with kind='scatter'."
warnings.warn(msg, stacklevel=2)
if weights is not None:
msg = "The `weights` parameter has no effect with kind='scatter'."
warnings.warn(msg, stacklevel=2)
p = Plotter(
data=data,
variables=variables,
Expand Down Expand Up @@ -780,17 +792,18 @@ def relplot(

# Add the grid semantics onto the plotter
grid_variables = dict(
x=x, y=y, row=row, col=col,
hue=hue, size=size, style=style,
x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
)
if kind == "line":
grid_variables["units"] = units
grid_variables.update(units=units, weights=weights)
p.assign_variables(data, grid_variables)

# Define the named variables for plotting on each facet
# Rename the variables with a leading underscore to avoid
# collisions with faceting variable names
plot_variables = {v: f"_{v}" for v in variables}
if "weight" in plot_variables:
plot_variables["weights"] = plot_variables.pop("weight")
plot_kws.update(plot_variables)

# Pass the row/col variables to FacetGrid with their original
Expand Down Expand Up @@ -918,6 +931,10 @@ def relplot(
Grouping variable that will produce elements with different styles.
Can have a numeric dtype but will always be treated as categorical.
{params.rel.units}
weights : vector or key in `data`
Data values or column used to compute weighted estimation.
Note that use of weights currently limits the choice of statistics
to a 'mean' estimator and 'ci' errorbar.
{params.facets.rowcol}
{params.facets.col_wrap}
row_order, col_order : lists of strings
Expand Down
20 changes: 20 additions & 0 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,13 @@ def test_estimate_func(self, long_df):
for i, bar in enumerate(ax.patches):
assert bar.get_height() == approx(agg_df[order[i]])

def test_weighted_estimate(self, long_df):
    """Check that `barplot` with `weights` draws the weighted mean of `y`.

    The height of the first bar is compared against `np.average` computed
    directly on the input columns.
    """
    ax = barplot(long_df, y="y", weights="x")
    height = ax.patches[0].get_height()
    expected = np.average(long_df["y"], weights=long_df["x"])
    # Compare with a tolerance, as the neighboring estimate tests do: the
    # plotted value and the reference are computed through different code
    # paths, so bit-for-bit float equality is not something to rely on.
    assert height == approx(expected)

def test_estimate_log_transform(self, long_df):

ax = mpl.figure.Figure().subplots()
Expand Down Expand Up @@ -2490,6 +2497,13 @@ def test_estimate(self, long_df, estimator):
for i, xy in enumerate(ax.lines[0].get_xydata()):
assert tuple(xy) == approx((i, agg_df[order[i]]))

def test_weighted_estimate(self, long_df):
    """Check that `pointplot` with `weights` draws the weighted mean of `y`.

    The single y value of the first line is compared against `np.average`
    computed directly on the input columns.
    """
    ax = pointplot(long_df, y="y", weights="x")
    # `.item()` requires the line to hold exactly one y value.
    val = ax.lines[0].get_ydata().item()
    expected = np.average(long_df["y"], weights=long_df["x"])
    # Compare with a tolerance, as the neighboring estimate tests do: the
    # plotted value and the reference are computed through different code
    # paths, so bit-for-bit float equality is not something to rely on.
    assert val == approx(expected)

def test_estimate_log_transform(self, long_df):

ax = mpl.figure.Figure().subplots()
Expand Down Expand Up @@ -3133,6 +3147,12 @@ def test_legend_with_auto(self):
g2 = catplot(self.df, x="g", y="y", hue="g", legend=True)
assert g2._legend is not None

def test_weights_warning(self, long_df):
    """`catplot` warns when `weights` is passed without `kind`, but still plots."""
    expected_message = "The `weights` parameter"
    with pytest.warns(UserWarning, match=expected_message):
        grid = catplot(long_df, x="a", y="y", weights="z")
    # The warning must not prevent the figure from being created.
    assert grid.ax is not None


class TestBeeswarm:

Expand Down
24 changes: 23 additions & 1 deletion tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df):
expected_paths = [paths[val] for val in grp_df["a"]]
assert self.paths_equal(points.get_paths(), expected_paths)

def test_relplot_weighted_estimator(self, long_df):
    """`relplot(kind="line")` with `weights` plots the weighted mean per `a` level."""
    grid = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
    line_y = grid.ax.lines[0].get_ydata()
    levels = categorical_order(long_df["a"])
    for idx, lvl in enumerate(levels):
        # Reference value: weighted average of `y` over the rows of this level.
        subset = long_df.loc[long_df["a"] == lvl]
        weighted_mean = np.average(subset["y"], weights=subset["x"])
        assert line_y[idx] == pytest.approx(weighted_mean)

def test_relplot_stringy_numerics(self, long_df):

long_df["x_str"] = long_df["x"].astype(str)
Expand Down Expand Up @@ -668,12 +677,16 @@ def test_facet_variable_collision(self, long_df):
)
assert g.axes.shape == (1, len(col_data.unique()))

def test_relplot_scatter_units(self, long_df):
def test_relplot_scatter_unused_variables(self, long_df):
    """`relplot` warns about `units` and `weights` when `kind` is not given."""
    # Each case: (parameter named in the warning, keyword arguments to pass).
    cases = [
        ("units", {"units": "a"}),
        ("weights", {"weights": "x"}),
    ]
    for param, extra_kws in cases:
        with pytest.warns(UserWarning, match=f"The `{param}` parameter"):
            grid = relplot(long_df, x="x", y="y", **extra_kws)
        # The warning must not prevent the figure from being created.
        assert grid.ax is not None

def test_ax_kwarg_removal(self, long_df):

f, ax = plt.subplots()
Expand Down Expand Up @@ -1055,6 +1068,15 @@ def test_plot(self, long_df, repeated_df):
ax.clear()
p.plot(ax, {})

def test_weights(self, long_df):
    """`lineplot` with `weights` plots the weighted mean of `y` per `a` level."""
    ax = lineplot(long_df, x="a", y="y", weights="x")
    line_y = ax.lines[0].get_ydata()
    levels = categorical_order(long_df["a"])
    for idx, lvl in enumerate(levels):
        # Reference value: weighted average of `y` over the rows of this level.
        subset = long_df.loc[long_df["a"] == lvl]
        weighted_mean = np.average(subset["y"], weights=subset["x"])
        assert line_y[idx] == pytest.approx(weighted_mean)

def test_non_aggregated_data(self):

x = [1, 2, 3, 4]
Expand Down
Loading