Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 92 additions & 38 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import inspect
import itertools
import textwrap
from contextlib import contextmanager
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any, Optional, cast

from cycler import cycler
import pandas as pd
from pandas import DataFrame, Series, Index
import matplotlib as mpl
Expand All @@ -30,6 +32,8 @@
from seaborn._core.typing import DataSource, VariableSpec, OrderSpec
from seaborn._core.rules import categorical_order
from seaborn._compat import set_scale_obj
from seaborn.rcmod import axes_style, plotting_context
from seaborn.palettes import color_palette
from seaborn.external.version import Version

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -148,6 +152,7 @@ class Plot:
_scales: dict[str, Scale]
_limits: dict[str, tuple[Any, Any]]
_labels: dict[str, str | Callable[[str], str] | None]
_theme: dict[str, Any]

_facet_spec: FacetSpec
_pair_spec: PairSpec
Expand Down Expand Up @@ -176,12 +181,13 @@ def __init__(
self._scales = {}
self._limits = {}
self._labels = {}
self._theme = {}

self._facet_spec = {}
self._pair_spec = {}

self._subplot_spec = {}
self._figure_spec = {}
self._subplot_spec = {}

self._target = None

Expand Down Expand Up @@ -256,6 +262,26 @@ def _clone(self) -> Plot:

return new

def _theme_with_defaults(self) -> dict[str, Any]:

style_groups = [
"axes", "figure", "font", "grid", "hatch", "legend", "lines",
"mathtext", "markers", "patch", "savefig", "scatter",
"xaxis", "xtick", "yaxis", "ytick",
]
base = {
k: v for k, v in mpl.rcParamsDefault.items()
if any(k.startswith(p) for p in style_groups)
}
theme = {
**base,
**axes_style("darkgrid"),
**plotting_context("notebook"),
"axes.prop_cycle": cycler("color", color_palette("deep")),
}
theme.update(self._theme)
return theme

@property
def _variables(self) -> list[str]:

Expand Down Expand Up @@ -629,44 +655,73 @@ def configure(

# TODO def legend (ugh)

def theme(self) -> Plot:
def theme(self, *args: dict[str, Any]) -> Plot:
"""
Control the default appearance of elements in the plot.

TODO
The API for customizing plot appearance is not yet finalized.
Currently, the only valid argument is a dict of matplotlib rc parameters.
(This dict must be passed as a positional argument.)

It is likely that this method will be enhanced in future releases.

"""
# TODO Plot-specific themes using the seaborn theming system
raise NotImplementedError()
new = self._clone()
return new

# TODO decorate? (or similar, for various texts) alt names: label?
# We can skip this whole block on Python 3.8+ with positional-only syntax
nargs = len(args)
if nargs != 1:
err = f"theme() takes 1 positional argument, but {nargs} were given"
raise TypeError(err)

rc = args[0]
new._theme.update(rc)

return new

def save(self, fname, **kwargs) -> Plot:
def save(self, loc, **kwargs) -> Plot:
"""
Render the plot and write it to a buffer or file on disk.
Compile the plot and write it to a buffer or file on disk.

Parameters
----------
fname : str, path, or buffer
loc : str, path, or buffer
Location on disk to save the figure, or a buffer to write into.
kwargs
Other keyword arguments are passed through to
:meth:`matplotlib.figure.Figure.savefig`.

"""
# TODO expose important keyword arguments in our signature?
self.plot().save(fname, **kwargs)
with theme_context(self._theme_with_defaults()):
self._plot().save(loc, **kwargs)
return self

def plot(self, pyplot=False) -> Plotter:
def show(self, **kwargs) -> None:
"""
Compile the plot spec and return a Plotter object.
Compile and display the plot by hooking into pyplot.
"""
# TODO make pyplot configurable at the class level, and when not using,
# import IPython.display and call on self to populate cell output?

# Keep an eye on whether matplotlib implements "attaching" an existing
# figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024

self.plot(pyplot=True).show(**kwargs)

def plot(self, pyplot: bool = False) -> Plotter:
"""
Compile the plot spec and return the Plotter object.
"""
with theme_context(self._theme_with_defaults()):
return self._plot(pyplot)

def _plot(self, pyplot: bool = False) -> Plotter:

# TODO if we have _target object, pyplot should be determined by whether it
# is hooked into the pyplot state machine (how do we check?)

plotter = Plotter(pyplot=pyplot)
plotter = Plotter(pyplot=pyplot, theme=self._theme_with_defaults())

# Process the variable assignments and initialize the figure
common, layers = plotter._extract_data(self)
Expand Down Expand Up @@ -697,18 +752,6 @@ def plot(self, pyplot=False) -> Plotter:

return plotter

def show(self, **kwargs) -> None:
"""
Render and display the plot.
"""
# TODO make pyplot configurable at the class level, and when not using,
# import IPython.display and call on self to populate cell output?

# Keep an eye on whether matplotlib implements "attaching" an existing
# figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024

self.plot(pyplot=True).show(**kwargs)


# ---- The plot compilation engine ---------------------------------------------- #

Expand All @@ -725,12 +768,13 @@ class Plotter:
_layers: list[Layer]
_figure: Figure

def __init__(self, pyplot=False):
def __init__(self, pyplot: bool, theme: dict[str, Any]):

self.pyplot = pyplot
self._legend_contents: list[
tuple[str, str | int], list[Artist], list[str],
] = []
self._pyplot = pyplot
self._theme = theme
self._legend_contents: list[tuple[
tuple[str | None, str | int], list[Artist], list[str],
]] = []
self._scales: dict[str, Scale] = {}

def save(self, loc, **kwargs) -> Plotter: # TODO type args
Expand All @@ -747,7 +791,8 @@ def show(self, **kwargs) -> None:
# TODO if we did not create the Plotter with pyplot, is it possible to do this?
# If not we should clearly raise.
import matplotlib.pyplot as plt
plt.show(**kwargs)
with theme_context(self._theme):
plt.show(**kwargs)

# TODO API for accessing the underlying matplotlib objects
# TODO what else is useful in the public API for this class?
Expand Down Expand Up @@ -781,11 +826,12 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]:

dpi = 96
buffer = io.BytesIO()
self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight")

with theme_context(self._theme):
self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight")
data = buffer.getvalue()

scaling = .85 / 2
# w, h = self._figure.get_size_inches()
w, h = Image.open(buffer).size
metadata = {"width": w * scaling, "height": h * scaling}
return data, metadata
Expand Down Expand Up @@ -824,9 +870,6 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

# --- Parsing the faceting/pairing parameterization to specify figure grid

# TODO use context manager with theme that has been set
# TODO (maybe wrap THIS function with context manager; would be cleaner)

subplot_spec = p._subplot_spec.copy()
facet_spec = p._facet_spec.copy()
pair_spec = p._pair_spec.copy()
Expand All @@ -840,7 +883,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

# --- Figure initialization
self._figure = subplots.init_figure(
pair_spec, self.pyplot, p._figure_spec, p._target,
pair_spec, self._pyplot, p._figure_spec, p._target,
)

# --- Figure annotation
Expand Down Expand Up @@ -1498,3 +1541,14 @@ def _finalize_figure(self, p: Plot) -> None:
# TODO this should be configurable
if not self._figure.get_constrained_layout():
self._figure.set_tight_layout(True)


@contextmanager
def theme_context(params: dict[str, Any]) -> Generator:
"""Temporarily modify specifc matplotlib rcParams."""
orig = {k: mpl.rcParams[k] for k in params}
try:
mpl.rcParams.update(params)
yield
finally:
mpl.rcParams.update(orig)
39 changes: 34 additions & 5 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import io
import xml
import functools
import itertools
import warnings
import imghdr

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image

import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
Expand Down Expand Up @@ -859,6 +861,23 @@ def test_paired_and_faceted(self, long_df):
assert_vector_equal(data["x"], long_df.loc[rows, x_i])
assert_vector_equal(data["y"], long_df.loc[rows, y])

def test_theme_default(self):

p = Plot().plot()
assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), "#EAEAF2")

def test_theme_params(self):

color = "r"
p = Plot().theme({"axes.facecolor": color}).plot()
assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), color)

def test_theme_error(self):

p = Plot()
with pytest.raises(TypeError, match=r"theme\(\) takes 1 positional"):
p.theme("arg1", "arg2")

def test_move(self, long_df):

orig_df = long_df.copy(deep=True)
Expand Down Expand Up @@ -949,21 +968,31 @@ def test_show(self):
if not gui_backend:
assert msg

def test_png_representation(self):
def test_png_repr(self):

p = Plot()
data, metadata = p._repr_png_()
img = Image.open(io.BytesIO(data))

assert not hasattr(p, "_figure")
assert isinstance(data, bytes)
assert imghdr.what("", data) == "png"
assert img.format == "PNG"
assert sorted(metadata) == ["height", "width"]
# TODO test retina scaling

@pytest.mark.xfail(reason="Plot.save not yet implemented")
def test_save(self):

Plot().save()
buf = io.BytesIO()

p = Plot().save(buf)
assert isinstance(p, Plot)
img = Image.open(buf)
assert img.format == "PNG"

buf = io.StringIO()
Plot().save(buf, format="svg")
tag = xml.etree.ElementTree.fromstring(buf.getvalue()).tag
assert tag == "{http://www.w3.org/2000/svg}svg"

def test_on_axes(self):

Expand Down
19 changes: 9 additions & 10 deletions tests/_marks/test_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,17 @@ def test_capstyle(self):
x = y = [1, 2]
rc = {"lines.solid_capstyle": "projecting", "lines.dash_capstyle": "round"}

with mpl.rc_context(rc):
p = Plot(x, y).add(Path()).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_dash_capstyle() == "projecting"
p = Plot(x, y).add(Path()).theme(rc).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_dash_capstyle() == "projecting"

p = Plot(x, y).add(Path(linestyle="--")).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_dash_capstyle() == "round"
p = Plot(x, y).add(Path(linestyle="--")).theme(rc).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_dash_capstyle() == "round"

p = Plot(x, y).add(Path({"solid_capstyle": "butt"})).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_solid_capstyle() == "butt"
p = Plot(x, y).add(Path({"solid_capstyle": "butt"})).theme(rc).plot()
line, = p._figure.axes[0].get_lines()
assert line.get_solid_capstyle() == "butt"


class TestLine:
Expand Down