Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
assert_coordinate_consistent,
remap_label_indexers,
)
from .dataset import Dataset, merge_indexes, split_indexes
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import Indexes, default_indexes
from .merge import PANDAS_TYPES
Expand Down Expand Up @@ -1601,10 +1601,9 @@ def set_index(
--------
DataArray.reset_index
"""
_check_inplace(inplace)
indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
coords, _ = merge_indexes(indexes, self._coords, set(), append=append)
return self._replace(coords=coords)
ds = self._to_temp_dataset().set_index(indexes, append=append,
inplace=inplace, **indexes_kwargs)
return self._from_temp_dataset(ds)

def reset_index(
self,
Expand Down
13 changes: 11 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def merge_indexes(
"""
vars_to_replace: Dict[Hashable, Variable] = {}
vars_to_remove: List[Hashable] = []
dims_to_replace: Dict[Hashable, Variable] = {}
error_msg = "{} is not the name of an existing variable."

for dim, var_names in indexes.items():
Expand Down Expand Up @@ -244,7 +245,7 @@ def merge_indexes(
if not len(names) and len(var_names) == 1:
idx = pd.Index(variables[var_names[0]].values)

else:
else: # MultiIndex
for n in var_names:
try:
var = variables[n]
Expand All @@ -256,15 +257,23 @@ def merge_indexes(
levels.append(cat.categories)

idx = pd.MultiIndex(levels, codes, names=names)
for n in names:
dims_to_replace[n] = dim

vars_to_replace[dim] = IndexVariable(dim, idx)
vars_to_remove.extend(var_names)

new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove}
new_variables.update(vars_to_replace)

# update dimensions if necessary GH: 3512
for k, v in new_variables.items():
if any(d in dims_to_replace for d in v.dims):
new_dims = [dims_to_replace.get(d, d) for d in v.dims]
new_variables[k] = type(v)(new_dims, v.data, attrs=v.attrs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI we could use v._replace here to save passing in the data that's not changing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @max-sixty

Does Variable have a _replace method? I don't see one in variable.py.
Yes, I wonder how I can smartly update just a dimension name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, one second, let me make this!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

encoding=v.encoding, fastpath=True)
new_coord_names = coord_names | set(vars_to_replace)
new_coord_names -= set(vars_to_remove)

return new_variables, new_coord_names


Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,16 @@ def test_selection_multiindex_remove_unused(self):
expected = expected.set_index(xy=["x", "y"]).unstack()
assert_identical(expected, actual)

def test_selection_multiindex_from_level(self):
    """Regression test for GH 3512.

    Builds a MultiIndex ``xy`` via ``set_index`` from the ``x`` and ``y``
    coordinates of two concatenated arrays, then checks that selecting on
    the ``y`` level agrees with the result obtained through positional
    indexing, unstacking, and dropping ``y``.
    """
    part_a = DataArray([0, 1], dims=['x'], coords={'x': [0, 1], 'y': 'a'})
    part_b = DataArray([2, 3], dims=['x'], coords={'x': [0, 1], 'y': 'b'})
    concatenated = xr.concat([part_a, part_b], dim='x')
    data = concatenated.set_index(xy=['x', 'y'])
    # After set_index the array must be indexed by the new dimension only.
    assert data.dims == ('xy', )

    # Reference result: positions 0 and 1 along ``xy``, unstacked, with the
    # ``y`` dimension squeezed out and its coordinate dropped.
    unstacked = data.isel(xy=[0, 1]).unstack('xy')
    expected = unstacked.squeeze('y').drop('y')

    actual = data.sel(y='a')
    assert_equal(actual, expected)

def test_virtual_default_coords(self):
array = DataArray(np.zeros((5,)), dims="x")
expected = DataArray(range(5), dims="x", name="x")
Expand Down