Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import textwrap
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any
from typing import Any, cast

import pandas as pd
from pandas import DataFrame, Series, Index
Expand Down Expand Up @@ -142,11 +142,11 @@ class Plot:
the plot without rendering it to access the lower-level representation.

"""
# TODO use TypedDict throughout?

_data: PlotData
_layers: list[Layer]

_scales: dict[str, Scale]
_limits: dict[str, tuple[Any, Any]]

_subplot_spec: dict[str, Any] # TODO values type
_facet_spec: FacetSpec
Expand All @@ -169,7 +169,9 @@ def __init__(

self._data = PlotData(data, variables)
self._layers = []

self._scales = {}
self._limits = {}

self._subplot_spec = {}
self._facet_spec = {}
Expand Down Expand Up @@ -543,6 +545,23 @@ def scale(self, **scales: Scale) -> Plot:
new._scales.update(**scales)
return new

def limit(self, **limits: tuple[Any, Any]) -> Plot:
"""
Control the range of visible data.

Keywords correspond to variables defined in the plot, and values are a
(min, max) tuple (where either can be `None` to leave unset).

Limits apply only to the axis scale; data outside the limits are still
used in any stat transforms and added to the plot.

Behavior for non-coordinate variables is currently undefined.

"""
new = self._clone()
new._limits.update(limits)
return new

def configure(
self,
figsize: tuple[float, float] | None = None,
Expand Down Expand Up @@ -634,11 +653,8 @@ def plot(self, pyplot=False) -> Plotter:
for layer in layers:
plotter._plot_layer(self, layer)

plotter._make_legend()

# TODO this should be configurable
if not plotter._figure.get_constrained_layout():
plotter._figure.set_tight_layout(True)
plotter._make_legend(self)
plotter._finalize_figure(self)

return plotter

Expand Down Expand Up @@ -1379,7 +1395,7 @@ def _update_legend_contents(

self._legend_contents.extend(contents)

def _make_legend(self) -> None:
def _make_legend(self, p: Plot) -> None:
"""Create the legend artist(s) and add onto the figure."""
# Combine artists representing same information across layers
# Input list has an entry for each distinct variable in each layer
Expand Down Expand Up @@ -1424,3 +1440,26 @@ def _make_legend(self) -> None:
else:
base_legend = legend
self._figure.legends.append(legend)

def _finalize_figure(self, p: Plot) -> None:

for sub in self._subplots:
ax = sub["ax"]
for axis in "xy":
axis_key = sub[axis]

# Axis limits
if axis_key in p._limits:
convert_units = getattr(ax, f"{axis}axis").convert_units
a, b = p._limits[axis_key]
lo = a if a is None else convert_units(a)
hi = b if b is None else convert_units(b)
if isinstance(a, str):
lo = cast(float, lo) - 0.5
if isinstance(b, str):
hi = cast(float, hi) + 0.5
ax.set(**{f"{axis}lim": (lo, hi)})

# TODO this should be configurable
if not self._figure.get_constrained_layout():
self._figure.set_tight_layout(True)
24 changes: 24 additions & 0 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,23 @@ def test_axis_labels_are_first_name(self, long_df):
assert ax.get_xlabel() == "a"
assert ax.get_ylabel() == "b"

def test_limits(self, long_df):

limit = (-2, 24)
p = Plot(long_df, x="x", y="y").limit(x=limit).plot()
ax1 = p._figure.axes[0]
assert ax1.get_xlim() == limit

limit = (np.datetime64("2005-01-01"), np.datetime64("2008-01-01"))
p = Plot(long_df, x="d", y="y").limit(x=limit).plot()
ax = p._figure.axes[0]
assert ax.get_xlim() == tuple(mpl.dates.date2num(limit))

limit = ("b", "c")
p = Plot(x=["a", "b", "c", "d"], y=[1, 2, 3, 4]).limit(x=limit).plot()
ax = p._figure.axes[0]
assert ax.get_xlim() == (0.5, 2.5)


class TestFacetInterface:

Expand Down Expand Up @@ -1382,6 +1399,13 @@ def test_two_variables_single_order_error(self, long_df):
with pytest.raises(RuntimeError, match=err):
p.facet(col="a", row="b", order=["a", "b", "c"])

def test_limits(self, long_df):

limit = (-2, 24)
p = Plot(long_df, y="y").pair(x=["x", "z"]).limit(x1=limit).plot()
ax1 = p._figure.axes[1]
assert ax1.get_xlim() == limit


class TestLabelVisibility:

Expand Down