Skip to content

Commit 3d83a8f

Browse files
committed
Tweak nominal scale axes akin to categorical axes in classic seaborn
1 parent aa56714 commit 3d83a8f

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

doc/whatsnew/v0.12.1.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ v0.12.1 (Unreleased)
66

77
- |Feature| Added the :class:`objects.Perc` stat (:pr:`3063`).
88

9-
- |Feature| The :class:`Band` and :class:`Range` marks will now cover the full extent of the data if `min` / `max` variables are not explicitly assigned or added in a transform (:pr:`3056`).
9+
- |Feature| The :class:`objects.Band` and :class:`objects.Range` marks will now cover the full extent of the data if `min` / `max` variables are not explicitly assigned or added in a transform (:pr:`3056`).
1010

11-
- |Enhancement| Marks that sort along the orient axis (e.g. :class:`Line`) now use a stable algorithm (:pr:`3064`).
11+
- |Enhancement| Marks that sort along the orient axis (e.g. :class:`objects.Line`) now use a stable algorithm (:pr:`3064`).
12+
13+
- |Enhancement| Axes with a :class:`objects.Nominal` scale now appear like categorical axes in class seaborn, with fixed margins, no grid, and an inverted y axis (:pr:`3069`).
1214

1315
- |Fix| Make :class:`objects.PolyFit` robust to missing data (:pr:`3010`).
1416

seaborn/_core/plot.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from seaborn._stats.base import Stat
2626
from seaborn._core.data import PlotData
2727
from seaborn._core.moves import Move
28-
from seaborn._core.scales import Scale
28+
from seaborn._core.scales import Scale, Nominal
2929
from seaborn._core.subplots import Subplots
3030
from seaborn._core.groupby import GroupBy
3131
from seaborn._core.properties import PROPERTIES, Property
@@ -1236,7 +1236,6 @@ def _setup_scales(
12361236
# This only affects us when sharing *paired* axes. This is a novel/niche
12371237
# behavior, so we will raise rather than hack together a workaround.
12381238
if axis is not None and Version(mpl.__version__) < Version("3.4.0"):
1239-
from seaborn._core.scales import Nominal
12401239
paired_axis = axis in p._pair_spec.get("structure", {})
12411240
cat_scale = isinstance(scale, Nominal)
12421241
ok_dim = {"x": "col", "y": "row"}[axis]
@@ -1629,6 +1628,7 @@ def _finalize_figure(self, p: Plot) -> None:
16291628
ax = sub["ax"]
16301629
for axis in "xy":
16311630
axis_key = sub[axis]
1631+
axis_obj = getattr(ax, f"{axis}axis")
16321632

16331633
# Axis limits
16341634
if axis_key in p._limits:
@@ -1642,6 +1642,17 @@ def _finalize_figure(self, p: Plot) -> None:
16421642
hi = cast(float, hi) + 0.5
16431643
ax.set(**{f"{axis}lim": (lo, hi)})
16441644

1645+
# Nominal scale special-casing
1646+
if isinstance(self._scales.get(axis_key), Nominal):
1647+
axis_obj.grid(False, which="both")
1648+
if axis_key not in p._limits:
1649+
nticks = len(axis_obj.get_major_ticks())
1650+
lo, hi = -.5, nticks - .5
1651+
if axis == "y":
1652+
lo, hi = hi, lo
1653+
set_lim = getattr(ax, f"set_{axis}lim")
1654+
set_lim(lo, hi, auto=None)
1655+
16451656
engine_default = None if p._target is not None else "tight"
16461657
layout_engine = p._layout_spec.get("engine", engine_default)
16471658
set_layout_engine(self._figure, layout_engine)

tests/_core/test_plot.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,28 @@ def test_undefined_variable_raises(self):
645645
with pytest.raises(RuntimeError, match=err):
646646
p.plot()
647647

648+
def test_nominal_x_axis_tweaks(self):
649+
650+
p = Plot(x=["a", "b", "c"], y=[1, 2, 3])
651+
ax1 = p.plot()._figure.axes[0]
652+
assert ax1.get_xlim() == (-.5, 2.5)
653+
assert not any(x.get_visible() for x in ax1.xaxis.get_gridlines())
654+
655+
lim = (-1, 2.1)
656+
ax2 = p.limit(x=lim).plot()._figure.axes[0]
657+
assert ax2.get_xlim() == lim
658+
659+
def test_nominal_y_axis_tweaks(self):
660+
661+
p = Plot(x=[1, 2, 3], y=["a", "b", "c"])
662+
ax1 = p.plot()._figure.axes[0]
663+
assert ax1.get_ylim() == (2.5, -.5)
664+
assert not any(y.get_visible() for y in ax1.yaxis.get_gridlines())
665+
666+
lim = (-1, 2.1)
667+
ax2 = p.limit(y=lim).plot()._figure.axes[0]
668+
assert ax2.get_ylim() == lim
669+
648670

649671
class TestPlotting:
650672

0 commit comments

Comments
 (0)