28
28
_version_predates ,
29
29
)
30
30
from seaborn ._compat import MarkerStyle
31
- from seaborn ._statistics import EstimateAggregator , LetterValues
31
+ from seaborn ._statistics import (
32
+ EstimateAggregator ,
33
+ LetterValues ,
34
+ WeightedAggregator ,
35
+ )
32
36
from seaborn .palettes import light_palette
33
37
from seaborn .axisgrid import FacetGrid , _facet_docs
34
38
@@ -1385,11 +1389,16 @@ class _CategoricalAggPlotter(_CategoricalPlotter):
1385
1389
.. versionadded:: v0.12.0
1386
1390
n_boot : int
1387
1391
Number of bootstrap samples used to compute confidence intervals.
1392
+ seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
1393
+ Seed or random number generator for reproducible bootstrapping.
1388
1394
units : name of variable in `data` or vector data
1389
1395
Identifier of sampling units; used by the errorbar function to
1390
1396
perform a multilevel bootstrap and account for repeated measures
1391
- seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
1392
- Seed or random number generator for reproducible bootstrapping.\
1397
+ weights : name of variable in `data` or vector data
1398
+ Data values or column used to compute weighted statistics.
1399
+ Note that the use of weights may limit other statistical options.
1400
+
1401
+ .. versionadded:: v0.13.1\
1393
1402
""" ),
1394
1403
ci = dedent ("""\
1395
1404
ci : float
@@ -2308,10 +2317,10 @@ def swarmplot(
2308
2317
2309
2318
def barplot (
2310
2319
data = None , * , x = None , y = None , hue = None , order = None , hue_order = None ,
2311
- estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , units = None , seed = None ,
2312
- orient = None , color = None , palette = None , saturation = .75 , fill = True , hue_norm = None ,
2313
- width = .8 , dodge = "auto" , gap = 0 , log_scale = None , native_scale = False , formatter = None ,
2314
- legend = "auto" , capsize = 0 , err_kws = None ,
2320
+ estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , seed = None , units = None ,
2321
+ weights = None , orient = None , color = None , palette = None , saturation = .75 ,
2322
+ fill = True , hue_norm = None , width = .8 , dodge = "auto" , gap = 0 , log_scale = None ,
2323
+ native_scale = False , formatter = None , legend = "auto" , capsize = 0 , err_kws = None ,
2315
2324
ci = deprecated , errcolor = deprecated , errwidth = deprecated , ax = None , ** kwargs ,
2316
2325
):
2317
2326
@@ -2324,7 +2333,7 @@ def barplot(
2324
2333
2325
2334
p = _CategoricalAggPlotter (
2326
2335
data = data ,
2327
- variables = dict (x = x , y = y , hue = hue , units = units ),
2336
+ variables = dict (x = x , y = y , hue = hue , units = units , weight = weights ),
2328
2337
order = order ,
2329
2338
orient = orient ,
2330
2339
color = color ,
@@ -2354,7 +2363,8 @@ def barplot(
2354
2363
p .map_hue (palette = palette , order = hue_order , norm = hue_norm , saturation = saturation )
2355
2364
color = _default_color (ax .bar , hue , color , kwargs , saturation = saturation )
2356
2365
2357
- aggregator = EstimateAggregator (estimator , errorbar , n_boot = n_boot , seed = seed )
2366
+ agg_cls = WeightedAggregator if "weight" in p .plot_data else EstimateAggregator
2367
+ aggregator = agg_cls (estimator , errorbar , n_boot = n_boot , seed = seed )
2358
2368
err_kws = {} if err_kws is None else _normalize_kwargs (err_kws , mpl .lines .Line2D )
2359
2369
2360
2370
# Deprecations to remove in v0.15.0.
@@ -2449,20 +2459,19 @@ def barplot(
2449
2459
2450
2460
def pointplot (
2451
2461
data = None , * , x = None , y = None , hue = None , order = None , hue_order = None ,
2452
- estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , units = None , seed = None ,
2453
- color = None , palette = None , hue_norm = None , markers = default , linestyles = default ,
2454
- dodge = False , log_scale = None , native_scale = False , orient = None , capsize = 0 ,
2455
- formatter = None , legend = "auto" , err_kws = None ,
2462
+ estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , seed = None , units = None ,
2463
+ weights = None , color = None , palette = None , hue_norm = None , markers = default ,
2464
+ linestyles = default , dodge = False , log_scale = None , native_scale = False ,
2465
+ orient = None , capsize = 0 , formatter = None , legend = "auto" , err_kws = None ,
2456
2466
ci = deprecated , errwidth = deprecated , join = deprecated , scale = deprecated ,
2457
- ax = None ,
2458
- ** kwargs ,
2467
+ ax = None , ** kwargs ,
2459
2468
):
2460
2469
2461
2470
errorbar = utils ._deprecate_ci (errorbar , ci )
2462
2471
2463
2472
p = _CategoricalAggPlotter (
2464
2473
data = data ,
2465
- variables = dict (x = x , y = y , hue = hue , units = units ),
2474
+ variables = dict (x = x , y = y , hue = hue , units = units , weight = weights ),
2466
2475
order = order ,
2467
2476
orient = orient ,
2468
2477
# Handle special backwards compatibility where pointplot originally
@@ -2489,7 +2498,8 @@ def pointplot(
2489
2498
p .map_hue (palette = palette , order = hue_order , norm = hue_norm )
2490
2499
color = _default_color (ax .plot , hue , color , kwargs )
2491
2500
2492
- aggregator = EstimateAggregator (estimator , errorbar , n_boot = n_boot , seed = seed )
2501
+ agg_cls = WeightedAggregator if "weight" in p .plot_data else EstimateAggregator
2502
+ aggregator = agg_cls (estimator , errorbar , n_boot = n_boot , seed = seed )
2493
2503
err_kws = {} if err_kws is None else _normalize_kwargs (err_kws , mpl .lines .Line2D )
2494
2504
2495
2505
# Deprecations to remove in v0.15.0.
@@ -2729,12 +2739,12 @@ def countplot(
2729
2739
2730
2740
def catplot (
2731
2741
data = None , * , x = None , y = None , hue = None , row = None , col = None , kind = "strip" ,
2732
- estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , units = None , seed = None ,
2733
- order = None , hue_order = None , row_order = None , col_order = None , col_wrap = None ,
2734
- height = 5 , aspect = 1 , log_scale = None , native_scale = False , formatter = None ,
2735
- orient = None , color = None , palette = None , hue_norm = None , legend = "auto" ,
2736
- legend_out = True , sharex = True , sharey = True , margin_titles = False , facet_kws = None ,
2737
- ci = deprecated , ** kwargs
2742
+ estimator = "mean" , errorbar = ("ci" , 95 ), n_boot = 1000 , seed = None , units = None ,
2743
+ weights = None , order = None , hue_order = None , row_order = None , col_order = None ,
2744
+ col_wrap = None , height = 5 , aspect = 1 , log_scale = None , native_scale = False ,
2745
+ formatter = None , orient = None , color = None , palette = None , hue_norm = None ,
2746
+ legend = "auto" , legend_out = True , sharex = True , sharey = True ,
2747
+ margin_titles = False , facet_kws = None , ci = deprecated , ** kwargs
2738
2748
):
2739
2749
2740
2750
# Check for attempt to plot onto specific axes and warn
@@ -2764,7 +2774,9 @@ def catplot(
2764
2774
2765
2775
p = Plotter (
2766
2776
data = data ,
2767
- variables = dict (x = x , y = y , hue = hue , row = row , col = col , units = units ),
2777
+ variables = dict (
2778
+ x = x , y = y , hue = hue , row = row , col = col , units = units , weight = weights
2779
+ ),
2768
2780
order = order ,
2769
2781
orient = orient ,
2770
2782
# Handle special backwards compatibility where pointplot originally
@@ -2840,6 +2852,14 @@ def catplot(
2840
2852
if dodge == "auto" :
2841
2853
dodge = p ._dodge_needed ()
2842
2854
2855
+ if "weight" in p .plot_data :
2856
+ if kind not in ["bar" , "point" ]:
2857
+ msg = f"The `weights` parameter has no effect with kind={ kind !r} ."
2858
+ warnings .warn (msg , stacklevel = 2 )
2859
+ agg_cls = WeightedAggregator
2860
+ else :
2861
+ agg_cls = EstimateAggregator
2862
+
2843
2863
if kind == "strip" :
2844
2864
2845
2865
jitter = kwargs .pop ("jitter" , True )
@@ -2989,9 +3009,7 @@ def catplot(
2989
3009
2990
3010
elif kind == "point" :
2991
3011
2992
- aggregator = EstimateAggregator (
2993
- estimator , errorbar , n_boot = n_boot , seed = seed
2994
- )
3012
+ aggregator = agg_cls (estimator , errorbar , n_boot = n_boot , seed = seed )
2995
3013
2996
3014
markers = kwargs .pop ("markers" , default )
2997
3015
linestyles = kwargs .pop ("linestyles" , default )
@@ -3025,9 +3043,8 @@ def catplot(
3025
3043
3026
3044
elif kind == "bar" :
3027
3045
3028
- aggregator = EstimateAggregator (
3029
- estimator , errorbar , n_boot = n_boot , seed = seed
3030
- )
3046
+ aggregator = agg_cls (estimator , errorbar , n_boot = n_boot , seed = seed )
3047
+
3031
3048
err_kws , capsize = p ._err_kws_backcompat (
3032
3049
_normalize_kwargs (kwargs .pop ("err_kws" , {}), mpl .lines .Line2D ),
3033
3050
errcolor = kwargs .pop ("errcolor" , deprecated ),
0 commit comments