Skip to content

Commit 762db89

Browse files
authored
Add rudimentary themeing support (#2929)
* WIP Plot.theme * Add default values for theme to match set_theme() * Depend on matplotib style defaults and update rcParams more selectively * Fix lines test * Improve test coverage
1 parent ff96e1f commit 762db89

File tree

3 files changed

+135
-53
lines changed

3 files changed

+135
-53
lines changed

seaborn/_core/plot.py

Lines changed: 92 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import inspect
99
import itertools
1010
import textwrap
11+
from contextlib import contextmanager
1112
from collections import abc
1213
from collections.abc import Callable, Generator, Hashable
1314
from typing import Any, Optional, cast
1415

16+
from cycler import cycler
1517
import pandas as pd
1618
from pandas import DataFrame, Series, Index
1719
import matplotlib as mpl
@@ -30,6 +32,8 @@
3032
from seaborn._core.typing import DataSource, VariableSpec, OrderSpec
3133
from seaborn._core.rules import categorical_order
3234
from seaborn._compat import set_scale_obj
35+
from seaborn.rcmod import axes_style, plotting_context
36+
from seaborn.palettes import color_palette
3337
from seaborn.external.version import Version
3438

3539
from typing import TYPE_CHECKING
@@ -148,6 +152,7 @@ class Plot:
148152
_scales: dict[str, Scale]
149153
_limits: dict[str, tuple[Any, Any]]
150154
_labels: dict[str, str | Callable[[str], str] | None]
155+
_theme: dict[str, Any]
151156

152157
_facet_spec: FacetSpec
153158
_pair_spec: PairSpec
@@ -176,12 +181,13 @@ def __init__(
176181
self._scales = {}
177182
self._limits = {}
178183
self._labels = {}
184+
self._theme = {}
179185

180186
self._facet_spec = {}
181187
self._pair_spec = {}
182188

183-
self._subplot_spec = {}
184189
self._figure_spec = {}
190+
self._subplot_spec = {}
185191

186192
self._target = None
187193

@@ -256,6 +262,26 @@ def _clone(self) -> Plot:
256262

257263
return new
258264

265+
def _theme_with_defaults(self) -> dict[str, Any]:
266+
267+
style_groups = [
268+
"axes", "figure", "font", "grid", "hatch", "legend", "lines",
269+
"mathtext", "markers", "patch", "savefig", "scatter",
270+
"xaxis", "xtick", "yaxis", "ytick",
271+
]
272+
base = {
273+
k: v for k, v in mpl.rcParamsDefault.items()
274+
if any(k.startswith(p) for p in style_groups)
275+
}
276+
theme = {
277+
**base,
278+
**axes_style("darkgrid"),
279+
**plotting_context("notebook"),
280+
"axes.prop_cycle": cycler("color", color_palette("deep")),
281+
}
282+
theme.update(self._theme)
283+
return theme
284+
259285
@property
260286
def _variables(self) -> list[str]:
261287

@@ -629,44 +655,73 @@ def configure(
629655

630656
# TODO def legend (ugh)
631657

632-
def theme(self) -> Plot:
658+
def theme(self, *args: dict[str, Any]) -> Plot:
633659
"""
634660
Control the default appearance of elements in the plot.
635661
636-
TODO
662+
The API for customizing plot appearance is not yet finalized.
663+
Currently, the only valid argument is a dict of matplotlib rc parameters.
664+
(This dict must be passed as a positional argument.)
665+
666+
It is likely that this method will be enhanced in future releases.
667+
637668
"""
638-
# TODO Plot-specific themes using the seaborn theming system
639-
raise NotImplementedError()
640669
new = self._clone()
641-
return new
642670

643-
# TODO decorate? (or similar, for various texts) alt names: label?
671+
# We can skip this whole block on Python 3.8+ with positional-only syntax
672+
nargs = len(args)
673+
if nargs != 1:
674+
err = f"theme() takes 1 positional argument, but {nargs} were given"
675+
raise TypeError(err)
676+
677+
rc = args[0]
678+
new._theme.update(rc)
679+
680+
return new
644681

645-
def save(self, fname, **kwargs) -> Plot:
682+
def save(self, loc, **kwargs) -> Plot:
646683
"""
647-
Render the plot and write it to a buffer or file on disk.
684+
Compile the plot and write it to a buffer or file on disk.
648685
649686
Parameters
650687
----------
651-
fname : str, path, or buffer
688+
loc : str, path, or buffer
652689
Location on disk to save the figure, or a buffer to write into.
653690
kwargs
654691
Other keyword arguments are passed through to
655692
:meth:`matplotlib.figure.Figure.savefig`.
656693
657694
"""
658695
# TODO expose important keyword arguments in our signature?
659-
self.plot().save(fname, **kwargs)
696+
with theme_context(self._theme_with_defaults()):
697+
self._plot().save(loc, **kwargs)
660698
return self
661699

662-
def plot(self, pyplot=False) -> Plotter:
700+
def show(self, **kwargs) -> None:
663701
"""
664-
Compile the plot spec and return a Plotter object.
702+
Compile and display the plot by hooking into pyplot.
665703
"""
704+
# TODO make pyplot configurable at the class level, and when not using,
705+
# import IPython.display and call on self to populate cell output?
706+
707+
# Keep an eye on whether matplotlib implements "attaching" an existing
708+
# figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
709+
710+
self.plot(pyplot=True).show(**kwargs)
711+
712+
def plot(self, pyplot: bool = False) -> Plotter:
713+
"""
714+
Compile the plot spec and return the Plotter object.
715+
"""
716+
with theme_context(self._theme_with_defaults()):
717+
return self._plot(pyplot)
718+
719+
def _plot(self, pyplot: bool = False) -> Plotter:
720+
666721
# TODO if we have _target object, pyplot should be determined by whether it
667722
# is hooked into the pyplot state machine (how do we check?)
668723

669-
plotter = Plotter(pyplot=pyplot)
724+
plotter = Plotter(pyplot=pyplot, theme=self._theme_with_defaults())
670725

671726
# Process the variable assignments and initialize the figure
672727
common, layers = plotter._extract_data(self)
@@ -697,18 +752,6 @@ def plot(self, pyplot=False) -> Plotter:
697752

698753
return plotter
699754

700-
def show(self, **kwargs) -> None:
701-
"""
702-
Render and display the plot.
703-
"""
704-
# TODO make pyplot configurable at the class level, and when not using,
705-
# import IPython.display and call on self to populate cell output?
706-
707-
# Keep an eye on whether matplotlib implements "attaching" an existing
708-
# figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
709-
710-
self.plot(pyplot=True).show(**kwargs)
711-
712755

713756
# ---- The plot compilation engine ---------------------------------------------- #
714757

@@ -725,12 +768,13 @@ class Plotter:
725768
_layers: list[Layer]
726769
_figure: Figure
727770

728-
def __init__(self, pyplot=False):
771+
def __init__(self, pyplot: bool, theme: dict[str, Any]):
729772

730-
self.pyplot = pyplot
731-
self._legend_contents: list[
732-
tuple[str, str | int], list[Artist], list[str],
733-
] = []
773+
self._pyplot = pyplot
774+
self._theme = theme
775+
self._legend_contents: list[tuple[
776+
tuple[str | None, str | int], list[Artist], list[str],
777+
]] = []
734778
self._scales: dict[str, Scale] = {}
735779

736780
def save(self, loc, **kwargs) -> Plotter: # TODO type args
@@ -747,7 +791,8 @@ def show(self, **kwargs) -> None:
747791
# TODO if we did not create the Plotter with pyplot, is it possible to do this?
748792
# If not we should clearly raise.
749793
import matplotlib.pyplot as plt
750-
plt.show(**kwargs)
794+
with theme_context(self._theme):
795+
plt.show(**kwargs)
751796

752797
# TODO API for accessing the underlying matplotlib objects
753798
# TODO what else is useful in the public API for this class?
@@ -781,11 +826,12 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]:
781826

782827
dpi = 96
783828
buffer = io.BytesIO()
784-
self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight")
829+
830+
with theme_context(self._theme):
831+
self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight")
785832
data = buffer.getvalue()
786833

787834
scaling = .85 / 2
788-
# w, h = self._figure.get_size_inches()
789835
w, h = Image.open(buffer).size
790836
metadata = {"width": w * scaling, "height": h * scaling}
791837
return data, metadata
@@ -824,9 +870,6 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
824870

825871
# --- Parsing the faceting/pairing parameterization to specify figure grid
826872

827-
# TODO use context manager with theme that has been set
828-
# TODO (maybe wrap THIS function with context manager; would be cleaner)
829-
830873
subplot_spec = p._subplot_spec.copy()
831874
facet_spec = p._facet_spec.copy()
832875
pair_spec = p._pair_spec.copy()
@@ -840,7 +883,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
840883

841884
# --- Figure initialization
842885
self._figure = subplots.init_figure(
843-
pair_spec, self.pyplot, p._figure_spec, p._target,
886+
pair_spec, self._pyplot, p._figure_spec, p._target,
844887
)
845888

846889
# --- Figure annotation
@@ -1498,3 +1541,14 @@ def _finalize_figure(self, p: Plot) -> None:
14981541
# TODO this should be configurable
14991542
if not self._figure.get_constrained_layout():
15001543
self._figure.set_tight_layout(True)
1544+
1545+
1546+
@contextmanager
1547+
def theme_context(params: dict[str, Any]) -> Generator:
1548+
"""Temporarily modify specifc matplotlib rcParams."""
1549+
orig = {k: mpl.rcParams[k] for k in params}
1550+
try:
1551+
mpl.rcParams.update(params)
1552+
yield
1553+
finally:
1554+
mpl.rcParams.update(orig)

tests/_core/test_plot.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import io
2+
import xml
13
import functools
24
import itertools
35
import warnings
4-
import imghdr
56

67
import numpy as np
78
import pandas as pd
89
import matplotlib as mpl
910
import matplotlib.pyplot as plt
11+
from PIL import Image
1012

1113
import pytest
1214
from pandas.testing import assert_frame_equal, assert_series_equal
@@ -859,6 +861,23 @@ def test_paired_and_faceted(self, long_df):
859861
assert_vector_equal(data["x"], long_df.loc[rows, x_i])
860862
assert_vector_equal(data["y"], long_df.loc[rows, y])
861863

864+
def test_theme_default(self):
865+
866+
p = Plot().plot()
867+
assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), "#EAEAF2")
868+
869+
def test_theme_params(self):
870+
871+
color = "r"
872+
p = Plot().theme({"axes.facecolor": color}).plot()
873+
assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), color)
874+
875+
def test_theme_error(self):
876+
877+
p = Plot()
878+
with pytest.raises(TypeError, match=r"theme\(\) takes 1 positional"):
879+
p.theme("arg1", "arg2")
880+
862881
def test_move(self, long_df):
863882

864883
orig_df = long_df.copy(deep=True)
@@ -949,21 +968,31 @@ def test_show(self):
949968
if not gui_backend:
950969
assert msg
951970

952-
def test_png_representation(self):
971+
def test_png_repr(self):
953972

954973
p = Plot()
955974
data, metadata = p._repr_png_()
975+
img = Image.open(io.BytesIO(data))
956976

957977
assert not hasattr(p, "_figure")
958978
assert isinstance(data, bytes)
959-
assert imghdr.what("", data) == "png"
979+
assert img.format == "PNG"
960980
assert sorted(metadata) == ["height", "width"]
961981
# TODO test retina scaling
962982

963-
@pytest.mark.xfail(reason="Plot.save not yet implemented")
964983
def test_save(self):
965984

966-
Plot().save()
985+
buf = io.BytesIO()
986+
987+
p = Plot().save(buf)
988+
assert isinstance(p, Plot)
989+
img = Image.open(buf)
990+
assert img.format == "PNG"
991+
992+
buf = io.StringIO()
993+
Plot().save(buf, format="svg")
994+
tag = xml.etree.ElementTree.fromstring(buf.getvalue()).tag
995+
assert tag == "{http://www.w3.org/2000/svg}svg"
967996

968997
def test_on_axes(self):
969998

tests/_marks/test_lines.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,17 @@ def test_capstyle(self):
119119
x = y = [1, 2]
120120
rc = {"lines.solid_capstyle": "projecting", "lines.dash_capstyle": "round"}
121121

122-
with mpl.rc_context(rc):
123-
p = Plot(x, y).add(Path()).plot()
124-
line, = p._figure.axes[0].get_lines()
125-
assert line.get_dash_capstyle() == "projecting"
122+
p = Plot(x, y).add(Path()).theme(rc).plot()
123+
line, = p._figure.axes[0].get_lines()
124+
assert line.get_dash_capstyle() == "projecting"
126125

127-
p = Plot(x, y).add(Path(linestyle="--")).plot()
128-
line, = p._figure.axes[0].get_lines()
129-
assert line.get_dash_capstyle() == "round"
126+
p = Plot(x, y).add(Path(linestyle="--")).theme(rc).plot()
127+
line, = p._figure.axes[0].get_lines()
128+
assert line.get_dash_capstyle() == "round"
130129

131-
p = Plot(x, y).add(Path({"solid_capstyle": "butt"})).plot()
132-
line, = p._figure.axes[0].get_lines()
133-
assert line.get_solid_capstyle() == "butt"
130+
p = Plot(x, y).add(Path({"solid_capstyle": "butt"})).theme(rc).plot()
131+
line, = p._figure.axes[0].get_lines()
132+
assert line.get_solid_capstyle() == "butt"
134133

135134

136135
class TestLine:

0 commit comments

Comments
 (0)