Skip to content

Commit 785242b

Browse files
committed
Fix relpot weights
1 parent 595dba7 commit 785242b

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

seaborn/relational.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,17 +792,18 @@ def relplot(
792792

793793
# Add the grid semantics onto the plotter
794794
grid_variables = dict(
795-
x=x, y=y, row=row, col=col,
796-
hue=hue, size=size, style=style,
795+
x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
797796
)
798797
if kind == "line":
799-
grid_variables["units"] = units
798+
grid_variables.update(units=units, weights=weights)
800799
p.assign_variables(data, grid_variables)
801800

802801
# Define the named variables for plotting on each facet
803802
# Rename the variables with a leading underscore to avoid
804803
# collisions with faceting variable names
805804
plot_variables = {v: f"_{v}" for v in variables}
805+
if "weight" in plot_variables:
806+
plot_variables["weights"] = plot_variables.pop("weight")
806807
plot_kws.update(plot_variables)
807808

808809
# Pass the row/col variables to FacetGrid with their original
@@ -930,6 +931,10 @@ def relplot(
930931
Grouping variable that will produce elements with different styles.
931932
Can have a numeric dtype but will always be treated as categorical.
932933
{params.rel.units}
934+
weights : vector or key in `data`
935+
Data values or column used to compute weighted estimation.
936+
Note that use of weights currently limits the choice of statistics
937+
to a 'mean' estimator and 'ci' errorbar.
933938
{params.facets.rowcol}
934939
{params.facets.col_wrap}
935940
row_order, col_order : lists of strings

tests/test_relational.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df):
578578
expected_paths = [paths[val] for val in grp_df["a"]]
579579
assert self.paths_equal(points.get_paths(), expected_paths)
580580

581+
def test_relplot_weighted_estimator(self, long_df):
582+
583+
g = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
584+
ydata = g.ax.lines[0].get_ydata()
585+
for i, label in enumerate(g.ax.get_xticklabels()):
586+
pos_df = long_df[long_df["a"] == label.get_text()]
587+
expected = np.average(pos_df["y"], weights=pos_df["x"])
588+
assert ydata[i] == pytest.approx(expected)
589+
581590
def test_relplot_stringy_numerics(self, long_df):
582591

583592
long_df["x_str"] = long_df["x"].astype(str)
@@ -1063,8 +1072,8 @@ def test_weights(self, long_df):
10631072

10641073
ax = lineplot(long_df, x="a", y="y", weights="x")
10651074
vals = ax.lines[0].get_ydata()
1066-
for i, a in enumerate(ax.get_xticklabels()):
1067-
pos_df = long_df.loc[long_df["a"] == a.get_text()]
1075+
for i, label in enumerate(ax.get_xticklabels()):
1076+
pos_df = long_df.loc[long_df["a"] == label.get_text()]
10681077
expected = np.average(pos_df["y"], weights=pos_df["x"])
10691078
assert vals[i] == pytest.approx(expected)
10701079

0 commit comments

Comments
 (0)