1
1
from __future__ import annotations
2
+ from collections import defaultdict
2
3
from dataclasses import dataclass
3
4
4
5
import numpy as np
23
24
from seaborn ._core .scales import Scale
24
25
25
26
26
- @dataclass
27
- class Bar (Mark ):
28
- """
29
- An interval mark drawn between baseline and data values with a width.
30
- """
31
- color : MappableColor = Mappable ("C0" , )
32
- alpha : MappableFloat = Mappable (.7 , )
33
- fill : MappableBool = Mappable (True , )
34
- edgecolor : MappableColor = Mappable (depend = "color" , )
35
- edgealpha : MappableFloat = Mappable (1 , )
36
- edgewidth : MappableFloat = Mappable (rc = "patch.linewidth" )
37
- edgestyle : MappableStyle = Mappable ("-" , )
38
- # pattern: MappableString = Mappable(None, ) # TODO no Property yet
27
+ class BarBase (Mark ):
39
28
40
- width : MappableFloat = Mappable (.8 , grouping = False )
41
- baseline : MappableFloat = Mappable (0 , grouping = False ) # TODO *is* this mappable?
29
+ def _make_patches (self , data , scales , orient ):
30
+
31
+ kws = self ._resolve_properties (data , scales )
32
+ if orient == "x" :
33
+ kws ["x" ] = (data ["x" ] - data ["width" ] / 2 ).to_numpy ()
34
+ kws ["y" ] = data ["baseline" ].to_numpy ()
35
+ kws ["w" ] = data ["width" ].to_numpy ()
36
+ kws ["h" ] = (data ["y" ] - data ["baseline" ]).to_numpy ()
37
+ else :
38
+ kws ["x" ] = data ["baseline" ].to_numpy ()
39
+ kws ["y" ] = (data ["y" ] - data ["width" ] / 2 ).to_numpy ()
40
+ kws ["w" ] = (data ["x" ] - data ["baseline" ]).to_numpy ()
41
+ kws ["h" ] = data ["width" ].to_numpy ()
42
+
43
+ kws .pop ("width" , None )
44
+ kws .pop ("baseline" , None )
45
+
46
+ val_dim = {"x" : "h" , "y" : "w" }[orient ]
47
+ bars , vals = [], []
48
+
49
+ for i in range (len (data )):
50
+
51
+ row = {k : v [i ] for k , v in kws .items ()}
52
+
53
+ # Skip bars with no value. It's possible we'll want to make this
54
+ # an option (i.e so you have an artist for animating or annotating),
55
+ # but let's keep things simple for now.
56
+ if not np .nan_to_num (row [val_dim ]):
57
+ continue
58
+
59
+ bar = mpl .patches .Rectangle (
60
+ xy = (row ["x" ], row ["y" ]),
61
+ width = row ["w" ],
62
+ height = row ["h" ],
63
+ facecolor = row ["facecolor" ],
64
+ edgecolor = row ["edgecolor" ],
65
+ linestyle = row ["edgestyle" ],
66
+ linewidth = row ["edgewidth" ],
67
+ ** self .artist_kws ,
68
+ )
69
+ bars .append (bar )
70
+ vals .append (row [val_dim ])
71
+
72
+ return bars , vals
42
73
43
74
def _resolve_properties (self , data , scales ):
44
75
@@ -56,58 +87,57 @@ def _resolve_properties(self, data, scales):
56
87
57
88
return resolved
58
89
59
- def _plot (self , split_gen , scales , orient ):
90
+ def _legend_artist (
91
+ self , variables : list [str ], value : Any , scales : dict [str , Scale ],
92
+ ) -> Artist :
93
+ # TODO return some sensible default?
94
+ key = {v : value for v in variables }
95
+ key = self ._resolve_properties (key , scales )
96
+ artist = mpl .patches .Patch (
97
+ facecolor = key ["facecolor" ],
98
+ edgecolor = key ["edgecolor" ],
99
+ linewidth = key ["edgewidth" ],
100
+ linestyle = key ["edgestyle" ],
101
+ )
102
+ return artist
60
103
61
- def coords_to_geometry (x , y , w , b ):
62
- # TODO possible too slow with lots of bars (e.g. dense hist)
63
- # Why not just use BarCollection?
64
- if orient == "x" :
65
- w , h = w , y - b
66
- xy = x - w / 2 , b
67
- else :
68
- w , h = x - b , w
69
- xy = b , y - h / 2
70
- return xy , w , h
71
104
72
- val_idx = ["y" , "x" ].index (orient )
105
+ @dataclass
106
+ class Bar (BarBase ):
107
+ """
108
+ An rectangular mark drawn between baseline and data values.
109
+ """
110
+ color : MappableColor = Mappable ("C0" , grouping = False )
111
+ alpha : MappableFloat = Mappable (.7 , grouping = False )
112
+ fill : MappableBool = Mappable (True , grouping = False )
113
+ edgecolor : MappableColor = Mappable (depend = "color" , grouping = False )
114
+ edgealpha : MappableFloat = Mappable (1 , grouping = False )
115
+ edgewidth : MappableFloat = Mappable (rc = "patch.linewidth" , grouping = False )
116
+ edgestyle : MappableStyle = Mappable ("-" , grouping = False )
117
+ # pattern: MappableString = Mappable(None) # TODO no Property yet
73
118
74
- for _ , data , ax in split_gen ():
119
+ width : MappableFloat = Mappable (.8 , grouping = False )
120
+ baseline : MappableFloat = Mappable (0 , grouping = False ) # TODO *is* this mappable?
75
121
76
- xys = data [["x" , "y" ]].to_numpy ()
77
- data = self ._resolve_properties (data , scales )
122
+ def _plot (self , split_gen , scales , orient ):
78
123
79
- bars , vals = [], []
80
- for i , (x , y ) in enumerate (xys ):
124
+ val_idx = ["y" , "x" ].index (orient )
81
125
82
- baseline = data [ "baseline" ][ i ]
83
- width = data [ "width" ][ i ]
84
- xy , w , h = coords_to_geometry ( x , y , width , baseline )
126
+ for _ , data , ax in split_gen ():
127
+
128
+ bars , vals = self . _make_patches ( data , scales , orient )
85
129
86
- # Skip bars with no value. It's possible we'll want to make this
87
- # an option (i.e so you have an artist for animating or annotating),
88
- # but let's keep things simple for now.
89
- if not np .nan_to_num (h ):
90
- continue
130
+ for bar in bars :
91
131
92
- # TODO Because we are clipping the artist (see below), the edges end up
132
+ # Because we are clipping the artist (see below), the edges end up
93
133
# looking half as wide as they actually are. I don't love this clumsy
94
134
# workaround, which is going to cause surprises if you work with the
95
135
# artists directly. We may need to revisit after feedback.
96
- linewidth = data [ "edgewidth" ][ i ] * 2
97
- linestyle = data [ "edgestyle" ][ i ]
136
+ bar . set_linewidth ( bar . get_linewidth () * 2 )
137
+ linestyle = bar . get_linestyle ()
98
138
if linestyle [1 ]:
99
139
linestyle = (linestyle [0 ], tuple (x / 2 for x in linestyle [1 ]))
100
-
101
- bar = mpl .patches .Rectangle (
102
- xy = xy ,
103
- width = w ,
104
- height = h ,
105
- facecolor = data ["facecolor" ][i ],
106
- edgecolor = data ["edgecolor" ][i ],
107
- linestyle = linestyle ,
108
- linewidth = linewidth ,
109
- ** self .artist_kws ,
110
- )
140
+ bar .set_linestyle (linestyle )
111
141
112
142
# This is a bit of a hack to handle the fact that the edge lines are
113
143
# centered on the actual extents of the bar, and overlap when bars are
@@ -121,8 +151,6 @@ def coords_to_geometry(x, y, w, b):
121
151
bar .set_clip_box (ax .bbox )
122
152
bar .sticky_edges [val_idx ][:] = (0 , np .inf )
123
153
ax .add_patch (bar )
124
- bars .append (bar )
125
- vals .append (h )
126
154
127
155
# Add a container which is useful for, e.g. Axes.bar_label
128
156
if Version (mpl .__version__ ) >= Version ("3.4.0" ):
@@ -133,16 +161,71 @@ def coords_to_geometry(x, y, w, b):
133
161
container = mpl .container .BarContainer (bars , ** container_kws )
134
162
ax .add_container (container )
135
163
136
- def _legend_artist (
137
- self , variables : list [str ], value : Any , scales : dict [str , Scale ],
138
- ) -> Artist :
139
- # TODO return some sensible default?
140
- key = {v : value for v in variables }
141
- key = self ._resolve_properties (key , scales )
142
- artist = mpl .patches .Patch (
143
- facecolor = key ["facecolor" ],
144
- edgecolor = key ["edgecolor" ],
145
- linewidth = key ["edgewidth" ],
146
- linestyle = key ["edgestyle" ],
147
- )
148
- return artist
164
+
165
+ @dataclass
166
+ class Bars (BarBase ):
167
+ """
168
+ A faster Bar mark with defaults that are more suitable for histograms.
169
+ """
170
+ color : MappableColor = Mappable ("C0" , grouping = False )
171
+ alpha : MappableFloat = Mappable (.7 , grouping = False )
172
+ fill : MappableBool = Mappable (True , grouping = False )
173
+ edgecolor : MappableColor = Mappable (rc = "patch.edgecolor" , grouping = False )
174
+ edgealpha : MappableFloat = Mappable (1 , grouping = False )
175
+ edgewidth : MappableFloat = Mappable (auto = True , grouping = False )
176
+ edgestyle : MappableStyle = Mappable ("-" , grouping = False )
177
+ # pattern: MappableString = Mappable(None) # TODO no Property yet
178
+
179
+ width : MappableFloat = Mappable (1 , grouping = False )
180
+ baseline : MappableFloat = Mappable (0 , grouping = False ) # TODO *is* this mappable?
181
+
182
+ def _plot (self , split_gen , scales , orient ):
183
+
184
+ ori_idx = ["x" , "y" ].index (orient )
185
+ val_idx = ["y" , "x" ].index (orient )
186
+
187
+ patches = defaultdict (list )
188
+ for _ , data , ax in split_gen ():
189
+ bars , _ = self ._make_patches (data , scales , orient )
190
+ patches [ax ].extend (bars )
191
+
192
+ collections = {}
193
+ for ax , ax_patches in patches .items ():
194
+
195
+ col = mpl .collections .PatchCollection (ax_patches , match_original = True )
196
+ col .sticky_edges [val_idx ][:] = (0 , np .inf )
197
+ ax .add_collection (col , autolim = False )
198
+ collections [ax ] = col
199
+
200
+ # Workaround for matplotlib autoscaling bug
201
+ # https://github.com/matplotlib/matplotlib/issues/11898
202
+ # https://github.com/matplotlib/matplotlib/issues/23129
203
+ xy = np .vstack ([path .vertices for path in col .get_paths ()])
204
+ ax .dataLim .update_from_data_xy (
205
+ xy , ax .ignore_existing_data_limits , updatex = True , updatey = True
206
+ )
207
+
208
+ if "edgewidth" not in scales and isinstance (self .edgewidth , Mappable ):
209
+
210
+ for ax in collections :
211
+ ax .autoscale_view ()
212
+
213
+ def get_dimensions (collection ):
214
+ edges , widths = [], []
215
+ for verts in (path .vertices for path in collection .get_paths ()):
216
+ edges .append (min (verts [:, ori_idx ]))
217
+ widths .append (np .ptp (verts [:, ori_idx ]))
218
+ return np .array (edges ), np .array (widths )
219
+
220
+ min_width = np .inf
221
+ for ax , col in collections .items ():
222
+ edges , widths = get_dimensions (col )
223
+ points = 72 / ax .figure .dpi * abs (
224
+ ax .transData .transform ([edges + widths ] * 2 )
225
+ - ax .transData .transform ([edges ] * 2 )
226
+ )
227
+ min_width = min (min_width , min (points [:, ori_idx ]))
228
+
229
+ linewidth = min (.1 * min_width , mpl .rcParams ["patch.linewidth" ])
230
+ for _ , col in collections .items ():
231
+ col .set_linewidth (linewidth )
0 commit comments