8
8
import inspect
9
9
import itertools
10
10
import textwrap
11
+ from contextlib import contextmanager
11
12
from collections import abc
12
13
from collections .abc import Callable , Generator , Hashable
13
14
from typing import Any , Optional , cast
14
15
16
+ from cycler import cycler
15
17
import pandas as pd
16
18
from pandas import DataFrame , Series , Index
17
19
import matplotlib as mpl
30
32
from seaborn ._core .typing import DataSource , VariableSpec , OrderSpec
31
33
from seaborn ._core .rules import categorical_order
32
34
from seaborn ._compat import set_scale_obj
35
+ from seaborn .rcmod import axes_style , plotting_context
36
+ from seaborn .palettes import color_palette
33
37
from seaborn .external .version import Version
34
38
35
39
from typing import TYPE_CHECKING
@@ -148,6 +152,7 @@ class Plot:
148
152
_scales : dict [str , Scale ]
149
153
_limits : dict [str , tuple [Any , Any ]]
150
154
_labels : dict [str , str | Callable [[str ], str ] | None ]
155
+ _theme : dict [str , Any ]
151
156
152
157
_facet_spec : FacetSpec
153
158
_pair_spec : PairSpec
@@ -176,12 +181,13 @@ def __init__(
176
181
self ._scales = {}
177
182
self ._limits = {}
178
183
self ._labels = {}
184
+ self ._theme = {}
179
185
180
186
self ._facet_spec = {}
181
187
self ._pair_spec = {}
182
188
183
- self ._subplot_spec = {}
184
189
self ._figure_spec = {}
190
+ self ._subplot_spec = {}
185
191
186
192
self ._target = None
187
193
@@ -256,6 +262,26 @@ def _clone(self) -> Plot:
256
262
257
263
return new
258
264
265
+ def _theme_with_defaults (self ) -> dict [str , Any ]:
266
+
267
+ style_groups = [
268
+ "axes" , "figure" , "font" , "grid" , "hatch" , "legend" , "lines" ,
269
+ "mathtext" , "markers" , "patch" , "savefig" , "scatter" ,
270
+ "xaxis" , "xtick" , "yaxis" , "ytick" ,
271
+ ]
272
+ base = {
273
+ k : v for k , v in mpl .rcParamsDefault .items ()
274
+ if any (k .startswith (p ) for p in style_groups )
275
+ }
276
+ theme = {
277
+ ** base ,
278
+ ** axes_style ("darkgrid" ),
279
+ ** plotting_context ("notebook" ),
280
+ "axes.prop_cycle" : cycler ("color" , color_palette ("deep" )),
281
+ }
282
+ theme .update (self ._theme )
283
+ return theme
284
+
259
285
@property
260
286
def _variables (self ) -> list [str ]:
261
287
@@ -629,44 +655,73 @@ def configure(
629
655
630
656
# TODO def legend (ugh)
631
657
632
- def theme (self ) -> Plot :
658
+ def theme (self , * args : dict [ str , Any ] ) -> Plot :
633
659
"""
634
660
Control the default appearance of elements in the plot.
635
661
636
- TODO
662
+ The API for customizing plot appearance is not yet finalized.
663
+ Currently, the only valid argument is a dict of matplotlib rc parameters.
664
+ (This dict must be passed as a positional argument.)
665
+
666
+ It is likely that this method will be enhanced in future releases.
667
+
637
668
"""
638
- # TODO Plot-specific themes using the seaborn theming system
639
- raise NotImplementedError ()
640
669
new = self ._clone ()
641
- return new
642
670
643
- # TODO decorate? (or similar, for various texts) alt names: label?
671
+ # We can skip this whole block on Python 3.8+ with positional-only syntax
672
+ nargs = len (args )
673
+ if nargs != 1 :
674
+ err = f"theme() takes 1 positional argument, but { nargs } were given"
675
+ raise TypeError (err )
676
+
677
+ rc = args [0 ]
678
+ new ._theme .update (rc )
679
+
680
+ return new
644
681
645
- def save (self , fname , ** kwargs ) -> Plot :
682
+ def save (self , loc , ** kwargs ) -> Plot :
646
683
"""
647
- Render the plot and write it to a buffer or file on disk.
684
+ Compile the plot and write it to a buffer or file on disk.
648
685
649
686
Parameters
650
687
----------
651
- fname : str, path, or buffer
688
+ loc : str, path, or buffer
652
689
Location on disk to save the figure, or a buffer to write into.
653
690
kwargs
654
691
Other keyword arguments are passed through to
655
692
:meth:`matplotlib.figure.Figure.savefig`.
656
693
657
694
"""
658
695
# TODO expose important keyword arguments in our signature?
659
- self .plot ().save (fname , ** kwargs )
696
+ with theme_context (self ._theme_with_defaults ()):
697
+ self ._plot ().save (loc , ** kwargs )
660
698
return self
661
699
662
- def plot (self , pyplot = False ) -> Plotter :
700
+ def show (self , ** kwargs ) -> None :
663
701
"""
664
- Compile the plot spec and return a Plotter object .
702
+ Compile and display the plot by hooking into pyplot .
665
703
"""
704
+ # TODO make pyplot configurable at the class level, and when not using,
705
+ # import IPython.display and call on self to populate cell output?
706
+
707
+ # Keep an eye on whether matplotlib implements "attaching" an existing
708
+ # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
709
+
710
+ self .plot (pyplot = True ).show (** kwargs )
711
+
712
+ def plot (self , pyplot : bool = False ) -> Plotter :
713
+ """
714
+ Compile the plot spec and return the Plotter object.
715
+ """
716
+ with theme_context (self ._theme_with_defaults ()):
717
+ return self ._plot (pyplot )
718
+
719
+ def _plot (self , pyplot : bool = False ) -> Plotter :
720
+
666
721
# TODO if we have _target object, pyplot should be determined by whether it
667
722
# is hooked into the pyplot state machine (how do we check?)
668
723
669
- plotter = Plotter (pyplot = pyplot )
724
+ plotter = Plotter (pyplot = pyplot , theme = self . _theme_with_defaults () )
670
725
671
726
# Process the variable assignments and initialize the figure
672
727
common , layers = plotter ._extract_data (self )
@@ -697,18 +752,6 @@ def plot(self, pyplot=False) -> Plotter:
697
752
698
753
return plotter
699
754
700
- def show (self , ** kwargs ) -> None :
701
- """
702
- Render and display the plot.
703
- """
704
- # TODO make pyplot configurable at the class level, and when not using,
705
- # import IPython.display and call on self to populate cell output?
706
-
707
- # Keep an eye on whether matplotlib implements "attaching" an existing
708
- # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
709
-
710
- self .plot (pyplot = True ).show (** kwargs )
711
-
712
755
713
756
# ---- The plot compilation engine ---------------------------------------------- #
714
757
@@ -725,12 +768,13 @@ class Plotter:
725
768
_layers : list [Layer ]
726
769
_figure : Figure
727
770
728
- def __init__ (self , pyplot = False ):
771
+ def __init__ (self , pyplot : bool , theme : dict [ str , Any ] ):
729
772
730
- self .pyplot = pyplot
731
- self ._legend_contents : list [
732
- tuple [str , str | int ], list [Artist ], list [str ],
733
- ] = []
773
+ self ._pyplot = pyplot
774
+ self ._theme = theme
775
+ self ._legend_contents : list [tuple [
776
+ tuple [str | None , str | int ], list [Artist ], list [str ],
777
+ ]] = []
734
778
self ._scales : dict [str , Scale ] = {}
735
779
736
780
def save (self , loc , ** kwargs ) -> Plotter : # TODO type args
@@ -747,7 +791,8 @@ def show(self, **kwargs) -> None:
747
791
# TODO if we did not create the Plotter with pyplot, is it possible to do this?
748
792
# If not we should clearly raise.
749
793
import matplotlib .pyplot as plt
750
- plt .show (** kwargs )
794
+ with theme_context (self ._theme ):
795
+ plt .show (** kwargs )
751
796
752
797
# TODO API for accessing the underlying matplotlib objects
753
798
# TODO what else is useful in the public API for this class?
@@ -781,11 +826,12 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]:
781
826
782
827
dpi = 96
783
828
buffer = io .BytesIO ()
784
- self ._figure .savefig (buffer , dpi = dpi * 2 , format = "png" , bbox_inches = "tight" )
829
+
830
+ with theme_context (self ._theme ):
831
+ self ._figure .savefig (buffer , dpi = dpi * 2 , format = "png" , bbox_inches = "tight" )
785
832
data = buffer .getvalue ()
786
833
787
834
scaling = .85 / 2
788
- # w, h = self._figure.get_size_inches()
789
835
w , h = Image .open (buffer ).size
790
836
metadata = {"width" : w * scaling , "height" : h * scaling }
791
837
return data , metadata
@@ -824,9 +870,6 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
824
870
825
871
# --- Parsing the faceting/pairing parameterization to specify figure grid
826
872
827
- # TODO use context manager with theme that has been set
828
- # TODO (maybe wrap THIS function with context manager; would be cleaner)
829
-
830
873
subplot_spec = p ._subplot_spec .copy ()
831
874
facet_spec = p ._facet_spec .copy ()
832
875
pair_spec = p ._pair_spec .copy ()
@@ -840,7 +883,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
840
883
841
884
# --- Figure initialization
842
885
self ._figure = subplots .init_figure (
843
- pair_spec , self .pyplot , p ._figure_spec , p ._target ,
886
+ pair_spec , self ._pyplot , p ._figure_spec , p ._target ,
844
887
)
845
888
846
889
# --- Figure annotation
@@ -1498,3 +1541,14 @@ def _finalize_figure(self, p: Plot) -> None:
1498
1541
# TODO this should be configurable
1499
1542
if not self ._figure .get_constrained_layout ():
1500
1543
self ._figure .set_tight_layout (True )
1544
+
1545
+
1546
+ @contextmanager
1547
+ def theme_context (params : dict [str , Any ]) -> Generator :
1548
+ """Temporarily modify specifc matplotlib rcParams."""
1549
+ orig = {k : mpl .rcParams [k ] for k in params }
1550
+ try :
1551
+ mpl .rcParams .update (params )
1552
+ yield
1553
+ finally :
1554
+ mpl .rcParams .update (orig )
0 commit comments