Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ New Features
By `Todd Jennings <https://github.com/toddrjen>`_
- Allow plotting of boolean arrays. (:pull:`3766`)
By `Marek Jacob <https://github.com/MeraX>`_
- Enable using MultiIndex levels as cordinates in 1D and 2D plots (:issue:`3927`).
By `Mathias Hauser <https://github.com/mathause>`_.
- A ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex`, analogous to
the ``days_in_month`` accessor for a :py:class:`pandas.DatetimeIndex`, which
returns the days in the month each datetime in the index. Now days in month
Expand Down
24 changes: 16 additions & 8 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,27 @@


def _infer_line_data(darray, x, y, hue):
error_msg = "must be either None or one of ({:s})".format(
", ".join([repr(dd) for dd in darray.dims])
)
error_msg = "must be a dimension, coordinate, MultiIndex level name or None"
ndims = len(darray.dims)

if x is not None and x not in darray.dims and x not in darray.coords:
raise ValueError("x " + error_msg)
if (
x is not None
and x not in darray.dims
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make valid_xy = set(darray.dims) + set(darray.coords) + set(darray._level_coords) and then print those values in the error message. It would make the if conditions on x and y pretty clean too.

Could do the same below too.

Is it possible to refactor out the x and y checking?

and x not in darray.coords
and x not in darray._level_coords
):
raise ValueError(f"x {error_msg}")

if y is not None and y not in darray.dims and y not in darray.coords:
raise ValueError("y " + error_msg)
if (
y is not None
and y not in darray.dims
and y not in darray.coords
and y not in darray._level_coords
):
raise ValueError(f"y {error_msg}")

if x is not None and y is not None:
raise ValueError("You cannot specify both x and y kwargs" "for line plots.")
raise ValueError("You cannot specify both x and y kwargs for line plots.")

if ndims == 1:
huename = None
Expand Down
39 changes: 32 additions & 7 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,24 +360,49 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):

darray must be a 2 dimensional data array, or 3d for imshow only.
"""
assert x is None or x != y
if (x is not None) and (x == y):
raise ValueError("'x' and 'y' cannot be equal.")

if imshow and darray.ndim == 3:
return _infer_xy_labels_3d(darray, x, y, rgb)

error_msg = "must be a dimension, coordinate or MultiIndex level name"
if x is None and y is None:
if darray.ndim != 2:
raise ValueError("DataArray must be 2d")
y, x = darray.dims
elif x is None:
if y not in darray.dims and y not in darray.coords:
raise ValueError("y must be a dimension name if x is not supplied")
if (
y not in darray.dims
and y not in darray.coords
and y not in darray._level_coords
):
raise ValueError(f"'y' {error_msg}")
x = darray.dims[0] if y == darray.dims[1] else darray.dims[1]
elif y is None:
if x not in darray.dims and x not in darray.coords:
raise ValueError("x must be a dimension name if y is not supplied")
if (
x not in darray.dims
and x not in darray.coords
and x not in darray._level_coords
):
raise ValueError(f"'x' {error_msg}")
y = darray.dims[0] if x == darray.dims[1] else darray.dims[1]
elif any(k not in darray.coords and k not in darray.dims for k in (x, y)):
raise ValueError("x and y must be coordinate variables")
else:
if any(
k not in darray.coords
and k not in darray.dims
and k not in darray._level_coords
for k in (x, y)
):
raise ValueError(f"'x' and 'y' {error_msg}s")
elif (
all(k in darray._level_coords for k in (x, y))
and darray._level_coords[x] == darray._level_coords[y]
):
raise ValueError("'x' and 'y' cannot be levels of the same MultiIndex")
elif darray._level_coords.get(x, x) == darray._level_coords.get(y, y):
raise ValueError("'x' and 'y' cannot be a MultiIndex and one of its levels")

return x, y


Expand Down
69 changes: 59 additions & 10 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,32 @@ def test_1d_x_y_kw(self):
for aa, (x, y) in enumerate(xy):
da.plot(x=x, y=y, ax=ax.flat[aa])

with raises_regex(ValueError, "cannot"):
with raises_regex(ValueError, "cannot specify both"):
da.plot(x="z", y="z")

with raises_regex(ValueError, "None"):
error_msg = "must be a dimension, coordinate, MultiIndex level name or None"
with raises_regex(ValueError, f"x {error_msg}"):
da.plot(x="f", y="z")

with raises_regex(ValueError, "None"):
with raises_regex(ValueError, f"y {error_msg}"):
da.plot(x="z", y="f")

def test_multiindex_level_as_coord(self):
da = xr.DataArray(
np.arange(5),
dims="x",
coords=dict(a=("x", np.arange(5)), b=("x", np.arange(5, 10))),
)
da = da.set_index(x=["a", "b"])

for x in ["a", "b"]:
h = da.plot(x=x)[0]
assert_array_equal(h.get_xdata(), da[x].values)

for y in ["a", "b"]:
h = da.plot(y=y)[0]
assert_array_equal(h.get_ydata(), da[y].values)

# Test for bug in GH issue #2725
def test_infer_line_data(self):
current = DataArray(
Expand Down Expand Up @@ -1031,6 +1048,16 @@ def test_nonnumeric_index_raises_typeerror(self):
with raises_regex(TypeError, r"[Pp]lot"):
self.plotfunc(a)

def test_multiindex_raises_typeerror(self):
a = DataArray(
easy_array((3, 2)),
dims=("x", "y"),
coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
)
a = a.set_index(y=("a", "b"))
with raises_regex(TypeError, r"[Pp]lot"):
self.plotfunc(a)

def test_can_pass_in_axis(self):
self.pass_in_axis(self.plotmethod)

Expand Down Expand Up @@ -1139,15 +1166,16 @@ def test_positional_coord_string(self):
assert "y_long_name [y_units]" == ax.get_ylabel()

def test_bad_x_string_exception(self):
with raises_regex(ValueError, "x and y must be coordinate variables"):

with raises_regex(ValueError, "'x' and 'y' cannot be equal."):
self.plotmethod(x="y", y="y")

error_msg = "must be a dimension, coordinate or MultiIndex level name"
with raises_regex(ValueError, f"'x' and 'y' {error_msg}"):
self.plotmethod("not_a_real_dim", "y")
with raises_regex(
ValueError, "x must be a dimension name if y is not supplied"
):
with raises_regex(ValueError, f"'x' {error_msg}"):
self.plotmethod(x="not_a_real_dim")
with raises_regex(
ValueError, "y must be a dimension name if x is not supplied"
):
with raises_regex(ValueError, f"'y' {error_msg}"):
self.plotmethod(y="not_a_real_dim")
self.darray.coords["z"] = 100

Expand Down Expand Up @@ -1182,6 +1210,27 @@ def test_non_linked_coords_transpose(self):
# simply ensure that these high coords were passed over
assert np.min(ax.get_xlim()) > 100.0

def test_multiindex_level_as_coord(self):
da = DataArray(
easy_array((3, 2)),
dims=("x", "y"),
coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
)
da = da.set_index(y=["a", "b"])

for x, y in (("a", "x"), ("b", "x"), ("x", "a"), ("x", "b")):
self.plotfunc(da, x=x, y=y)

ax = plt.gca()
assert x == ax.get_xlabel()
assert y == ax.get_ylabel()

with raises_regex(ValueError, "levels of the same MultiIndex"):
self.plotfunc(da, x="a", y="b")

with raises_regex(ValueError, "MultiIndex and one of its levels"):
self.plotfunc(da, x="a", y="y")

def test_default_title(self):
a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"])
a.coords["c"] = [0, 1]
Expand Down