Skip to content

Commit 3c8447f

Browse files
authored
Add Bars for more efficient bar plots and improve Bar as well (#2893)
* Vectorize bar->patch reparameterization
* Refactor bar setup
* Add Bars mark that uses a PatchCollection
* Work around matplotlib bug to autoscale Bars correctly
* Reorganize function
* Add tests
* Fix min edgewidth calculation
* Autoscale all axes before computing minimum edgewidth
* Test auto edgewidth
1 parent 3180fd7 commit 3c8447f

File tree

6 files changed

+257
-74
lines changed

6 files changed

+257
-74
lines changed

seaborn/_marks/bars.py

Lines changed: 152 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from collections import defaultdict
23
from dataclasses import dataclass
34

45
import numpy as np
@@ -23,22 +24,52 @@
2324
from seaborn._core.scales import Scale
2425

2526

26-
@dataclass
27-
class Bar(Mark):
28-
"""
29-
An interval mark drawn between baseline and data values with a width.
30-
"""
31-
color: MappableColor = Mappable("C0", )
32-
alpha: MappableFloat = Mappable(.7, )
33-
fill: MappableBool = Mappable(True, )
34-
edgecolor: MappableColor = Mappable(depend="color", )
35-
edgealpha: MappableFloat = Mappable(1, )
36-
edgewidth: MappableFloat = Mappable(rc="patch.linewidth")
37-
edgestyle: MappableStyle = Mappable("-", )
38-
# pattern: MappableString = Mappable(None, ) # TODO no Property yet
27+
class BarBase(Mark):
3928

40-
width: MappableFloat = Mappable(.8, grouping=False)
41-
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
29+
def _make_patches(self, data, scales, orient):
30+
31+
kws = self._resolve_properties(data, scales)
32+
if orient == "x":
33+
kws["x"] = (data["x"] - data["width"] / 2).to_numpy()
34+
kws["y"] = data["baseline"].to_numpy()
35+
kws["w"] = data["width"].to_numpy()
36+
kws["h"] = (data["y"] - data["baseline"]).to_numpy()
37+
else:
38+
kws["x"] = data["baseline"].to_numpy()
39+
kws["y"] = (data["y"] - data["width"] / 2).to_numpy()
40+
kws["w"] = (data["x"] - data["baseline"]).to_numpy()
41+
kws["h"] = data["width"].to_numpy()
42+
43+
kws.pop("width", None)
44+
kws.pop("baseline", None)
45+
46+
val_dim = {"x": "h", "y": "w"}[orient]
47+
bars, vals = [], []
48+
49+
for i in range(len(data)):
50+
51+
row = {k: v[i] for k, v in kws.items()}
52+
53+
# Skip bars with no value. It's possible we'll want to make this
54+
# an option (i.e., so you have an artist for animating or annotating),
55+
# but let's keep things simple for now.
56+
if not np.nan_to_num(row[val_dim]):
57+
continue
58+
59+
bar = mpl.patches.Rectangle(
60+
xy=(row["x"], row["y"]),
61+
width=row["w"],
62+
height=row["h"],
63+
facecolor=row["facecolor"],
64+
edgecolor=row["edgecolor"],
65+
linestyle=row["edgestyle"],
66+
linewidth=row["edgewidth"],
67+
**self.artist_kws,
68+
)
69+
bars.append(bar)
70+
vals.append(row[val_dim])
71+
72+
return bars, vals
4273

4374
def _resolve_properties(self, data, scales):
4475

@@ -56,58 +87,57 @@ def _resolve_properties(self, data, scales):
5687

5788
return resolved
5889

59-
def _plot(self, split_gen, scales, orient):
90+
def _legend_artist(
91+
self, variables: list[str], value: Any, scales: dict[str, Scale],
92+
) -> Artist:
93+
# TODO return some sensible default?
94+
key = {v: value for v in variables}
95+
key = self._resolve_properties(key, scales)
96+
artist = mpl.patches.Patch(
97+
facecolor=key["facecolor"],
98+
edgecolor=key["edgecolor"],
99+
linewidth=key["edgewidth"],
100+
linestyle=key["edgestyle"],
101+
)
102+
return artist
60103

61-
def coords_to_geometry(x, y, w, b):
62-
# TODO possibly too slow with lots of bars (e.g. dense hist)
63-
# Why not just use BarCollection?
64-
if orient == "x":
65-
w, h = w, y - b
66-
xy = x - w / 2, b
67-
else:
68-
w, h = x - b, w
69-
xy = b, y - h / 2
70-
return xy, w, h
71104

72-
val_idx = ["y", "x"].index(orient)
105+
@dataclass
106+
class Bar(BarBase):
107+
"""
108+
A rectangular mark drawn between baseline and data values.
109+
"""
110+
color: MappableColor = Mappable("C0", grouping=False)
111+
alpha: MappableFloat = Mappable(.7, grouping=False)
112+
fill: MappableBool = Mappable(True, grouping=False)
113+
edgecolor: MappableColor = Mappable(depend="color", grouping=False)
114+
edgealpha: MappableFloat = Mappable(1, grouping=False)
115+
edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
116+
edgestyle: MappableStyle = Mappable("-", grouping=False)
117+
# pattern: MappableString = Mappable(None) # TODO no Property yet
73118

74-
for _, data, ax in split_gen():
119+
width: MappableFloat = Mappable(.8, grouping=False)
120+
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
75121

76-
xys = data[["x", "y"]].to_numpy()
77-
data = self._resolve_properties(data, scales)
122+
def _plot(self, split_gen, scales, orient):
78123

79-
bars, vals = [], []
80-
for i, (x, y) in enumerate(xys):
124+
val_idx = ["y", "x"].index(orient)
81125

82-
baseline = data["baseline"][i]
83-
width = data["width"][i]
84-
xy, w, h = coords_to_geometry(x, y, width, baseline)
126+
for _, data, ax in split_gen():
127+
128+
bars, vals = self._make_patches(data, scales, orient)
85129

86-
# Skip bars with no value. It's possible we'll want to make this
87-
# an option (i.e so you have an artist for animating or annotating),
88-
# but let's keep things simple for now.
89-
if not np.nan_to_num(h):
90-
continue
130+
for bar in bars:
91131

92-
# TODO Because we are clipping the artist (see below), the edges end up
132+
# Because we are clipping the artist (see below), the edges end up
93133
# looking half as wide as they actually are. I don't love this clumsy
94134
# workaround, which is going to cause surprises if you work with the
95135
# artists directly. We may need to revisit after feedback.
96-
linewidth = data["edgewidth"][i] * 2
97-
linestyle = data["edgestyle"][i]
136+
bar.set_linewidth(bar.get_linewidth() * 2)
137+
linestyle = bar.get_linestyle()
98138
if linestyle[1]:
99139
linestyle = (linestyle[0], tuple(x / 2 for x in linestyle[1]))
100-
101-
bar = mpl.patches.Rectangle(
102-
xy=xy,
103-
width=w,
104-
height=h,
105-
facecolor=data["facecolor"][i],
106-
edgecolor=data["edgecolor"][i],
107-
linestyle=linestyle,
108-
linewidth=linewidth,
109-
**self.artist_kws,
110-
)
140+
bar.set_linestyle(linestyle)
111141

112142
# This is a bit of a hack to handle the fact that the edge lines are
113143
# centered on the actual extents of the bar, and overlap when bars are
@@ -121,8 +151,6 @@ def coords_to_geometry(x, y, w, b):
121151
bar.set_clip_box(ax.bbox)
122152
bar.sticky_edges[val_idx][:] = (0, np.inf)
123153
ax.add_patch(bar)
124-
bars.append(bar)
125-
vals.append(h)
126154

127155
# Add a container which is useful for, e.g. Axes.bar_label
128156
if Version(mpl.__version__) >= Version("3.4.0"):
@@ -133,16 +161,71 @@ def coords_to_geometry(x, y, w, b):
133161
container = mpl.container.BarContainer(bars, **container_kws)
134162
ax.add_container(container)
135163

136-
def _legend_artist(
137-
self, variables: list[str], value: Any, scales: dict[str, Scale],
138-
) -> Artist:
139-
# TODO return some sensible default?
140-
key = {v: value for v in variables}
141-
key = self._resolve_properties(key, scales)
142-
artist = mpl.patches.Patch(
143-
facecolor=key["facecolor"],
144-
edgecolor=key["edgecolor"],
145-
linewidth=key["edgewidth"],
146-
linestyle=key["edgestyle"],
147-
)
148-
return artist
164+
165+
@dataclass
166+
class Bars(BarBase):
167+
"""
168+
A faster Bar mark with defaults that are more suitable for histograms.
169+
"""
170+
color: MappableColor = Mappable("C0", grouping=False)
171+
alpha: MappableFloat = Mappable(.7, grouping=False)
172+
fill: MappableBool = Mappable(True, grouping=False)
173+
edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
174+
edgealpha: MappableFloat = Mappable(1, grouping=False)
175+
edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
176+
edgestyle: MappableStyle = Mappable("-", grouping=False)
177+
# pattern: MappableString = Mappable(None) # TODO no Property yet
178+
179+
width: MappableFloat = Mappable(1, grouping=False)
180+
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
181+
182+
def _plot(self, split_gen, scales, orient):
183+
184+
ori_idx = ["x", "y"].index(orient)
185+
val_idx = ["y", "x"].index(orient)
186+
187+
patches = defaultdict(list)
188+
for _, data, ax in split_gen():
189+
bars, _ = self._make_patches(data, scales, orient)
190+
patches[ax].extend(bars)
191+
192+
collections = {}
193+
for ax, ax_patches in patches.items():
194+
195+
col = mpl.collections.PatchCollection(ax_patches, match_original=True)
196+
col.sticky_edges[val_idx][:] = (0, np.inf)
197+
ax.add_collection(col, autolim=False)
198+
collections[ax] = col
199+
200+
# Workaround for matplotlib autoscaling bug
201+
# https://github.com/matplotlib/matplotlib/issues/11898
202+
# https://github.com/matplotlib/matplotlib/issues/23129
203+
xy = np.vstack([path.vertices for path in col.get_paths()])
204+
ax.dataLim.update_from_data_xy(
205+
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
206+
)
207+
208+
if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):
209+
210+
for ax in collections:
211+
ax.autoscale_view()
212+
213+
def get_dimensions(collection):
214+
edges, widths = [], []
215+
for verts in (path.vertices for path in collection.get_paths()):
216+
edges.append(min(verts[:, ori_idx]))
217+
widths.append(np.ptp(verts[:, ori_idx]))
218+
return np.array(edges), np.array(widths)
219+
220+
min_width = np.inf
221+
for ax, col in collections.items():
222+
edges, widths = get_dimensions(col)
223+
points = 72 / ax.figure.dpi * abs(
224+
ax.transData.transform([edges + widths] * 2)
225+
- ax.transData.transform([edges] * 2)
226+
)
227+
min_width = min(min_width, min(points[:, ori_idx]))
228+
229+
linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
230+
for _, col in collections.items():
231+
col.set_linewidth(linewidth)

seaborn/_marks/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
val: Any = None,
2828
depend: str | None = None,
2929
rc: str | None = None,
30+
auto: bool = False,
3031
grouping: bool = True,
3132
):
3233
"""
@@ -40,6 +41,8 @@ def __init__(
4041
Use the value of this feature as the default.
4142
rc : str
4243
Use the value of this rcParam as the default.
44+
auto : bool
45+
The default value will depend on other parameters at compile time.
4346
grouping : bool
4447
If True, use the mapped variable to define groups.
4548
@@ -52,6 +55,7 @@ def __init__(
5255
self._val = val
5356
self._rc = rc
5457
self._depend = depend
58+
self._auto = auto
5559
self._grouping = grouping
5660

5761
def __repr__(self):
@@ -62,6 +66,8 @@ def __repr__(self):
6266
s = f"<depend:{self._depend}>"
6367
elif self._rc is not None:
6468
s = f"<rc:{self._rc}>"
69+
elif self._auto:
70+
s = "<auto>"
6571
else:
6672
s = "<undefined>"
6773
return s

seaborn/_marks/scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _plot(self, split_gen, scales, orient):
9191
# (That should be solved upstream by defaulting to "" for unset x/y?)
9292
# (Be mindful of xmin/xmax, etc!)
9393

94-
for keys, data, ax in split_gen():
94+
for _, data, ax in split_gen():
9595

9696
offsets = np.column_stack([data["x"], data["y"]])
9797
data = self._resolve_properties(data, scales)

seaborn/objects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from seaborn._marks.base import Mark # noqa: F401
77
from seaborn._marks.area import Area, Ribbon # noqa: F401
8-
from seaborn._marks.bars import Bar # noqa: F401
8+
from seaborn._marks.bars import Bar, Bars # noqa: F401
99
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
1010
from seaborn._marks.scatter import Dot, Scatter # noqa: F401
1111

0 commit comments

Comments
 (0)