1 change: 1 addition & 0 deletions doc/api.rst
@@ -144,6 +144,7 @@ Computation
:py:attr:`~Dataset.round`
:py:attr:`~Dataset.real`
:py:attr:`~Dataset.T`
:py:attr:`~DataArray.dot`

**Grouped operations**:
:py:attr:`~core.groupby.DatasetGroupBy.assign`
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -18,6 +18,8 @@ v0.7.2 (unreleased)

Enhancements
~~~~~~~~~~~~
- New :py:meth:`~DataArray.dot` method, an xarray version of ``np.dot``:
  performs the dot product of two DataArrays along their shared dimensions.

- Rolling window operations on DataArray objects are now supported via a new
:py:meth:`xarray.DataArray.rolling` method.
59 changes: 59 additions & 0 deletions xarray/core/dataarray.py
@@ -1374,5 +1374,64 @@ def real(self):
def imag(self):
return self._replace(self.variable.imag)

def dot(self, other):
"""Perform dot product of two DataArrays along their shared dims.

        Equivalent to taking np.tensordot over all shared dimensions.

Parameters
----------
other : DataArray
The other array with which the dot product is performed.

Returns
-------
result : DataArray
Array resulting from the dot product over all shared dimensions.

See also
--------
        numpy.tensordot

Examples
--------

>>> da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
>>> da = DataArray(da_vals, dims=['x', 'y', 'z'])
>>> dm_vals = np.arange(4)
>>> dm = DataArray(dm_vals, dims=['z'])

>>> dm.dims
        ('z',)
>>> da.dims
('x', 'y', 'z')

>>> dot_result = da.dot(dm)
>>> dot_result.dims
('x', 'y')
"""
if isinstance(other, Dataset):
raise NotImplementedError('dot products are not yet supported '
'with Dataset objects.')
if not isinstance(other, DataArray):
raise TypeError('dot only operates on DataArrays.')
Member:
I think it would be fine to remove lines 1413:1415, and change 1417 to something like:

if not isinstance(other, DataArray):
    raise TypeError('dot only operates on DataArrays, got {}'.format(type(other)))

Member:
Yeah, could go either way here. I like raising NotImplementedError because it makes it more obvious to users that this behavior might change in the future. Besides, if we get lucky a user will notice the error and then implement the missing functionality themselves ;).


# sum over the common dims
dims = set(self.dims) & set(other.dims)
if len(dims) == 0:
raise ValueError('DataArrays have no shared dimensions over which '
'to perform dot.')

self, other = align(self, other, join='inner', copy=False)

axes = (self.get_axis_num(dims), other.get_axis_num(dims))
new_data = ops.tensordot(self.data, other.data, axes=axes)

new_coords = self.coords.merge(other.coords).drop(dims)
new_dims = ([d for d in self.dims if d not in dims] +
[d for d in other.dims if d not in dims])

return type(self)(new_data, new_coords, new_dims)
Member:
type(self) will always be a DataArray, yes?

Author:
Maybe not. It depends on whether we want to figure out a way to extend it to Dataset at some point.

Member:
Sure, but you raise an exception if other is not a DataArray. So until you apply this to Dataset, this is just an abstraction.

Author:
I'm happy either way; @shoyer suggested this approach, and the way you're suggesting is how I originally did it. I would side with you that it should be generic only when it needs to be generic.

Member:
Well, this is what you get for listening to @shoyer. Either way is fine by me. I know what you were going for, and it makes sense. This way is just a little harder to read and includes one extra function call.

Collaborator:
Or if you inherit from DataArray, then the type call will return the inherited class, which is good

Member:
In general, I prefer type(self), and most other xarray code uses it:

  1. As @MaximilianR mentions, it works better with subclassing.
  2. It avoids repeating the class name twice, which makes refactoring easier (this is similar to how Python 3 removed the class name as a required argument to super).

But in the scheme of things it isn't very important :).
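A minimal sketch of the subclassing point under discussion (the classes here are hypothetical, not xarray's): a method that constructs its result via ``type(self)`` preserves a subclass, while hard-coding the class name does not.

```python
class Base:
    def __init__(self, data):
        self.data = data

    def doubled_generic(self):
        # type(self) picks up the runtime class, so a subclass stays a subclass
        return type(self)(self.data * 2)

    def doubled_hardcoded(self):
        # hard-coding the class always drops back to Base
        return Base(self.data * 2)


class Sub(Base):
    pass


s = Sub(3)
print(type(s.doubled_generic()).__name__)    # Sub
print(type(s.doubled_hardcoded()).__name__)  # Base
```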


# priority must be higher than Variable to properly work with binary ufuncs
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60)
2 changes: 2 additions & 0 deletions xarray/core/ops.py
@@ -97,6 +97,8 @@ def _fail_on_dask_array_input(values, msg=None, func_name=None):
array_all = _dask_or_eager_func('all')
array_any = _dask_or_eager_func('any')

tensordot = _dask_or_eager_func('tensordot', n_array_args=2)


def _interleaved_indices_required(indices):
"""With dask, we care about data locality and would rather avoid splitting
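The ``_dask_or_eager_func('tensordot', ...)`` line above can be understood with a rough sketch of such a dispatch helper (this is an illustrative assumption, not xarray's actual implementation): call dask.array's version of a NumPy function when any array argument is a dask array, otherwise fall back to NumPy.

```python
import numpy as np


def dask_or_eager(name, n_array_args=1):
    """Illustrative dispatcher: dask.array version if any input is lazy, else NumPy."""
    def wrapper(*args, **kwargs):
        array_args = args[:n_array_args]
        try:
            import dask.array as da
            if any(isinstance(a, da.Array) for a in array_args):
                return getattr(da, name)(*args, **kwargs)
        except ImportError:
            pass  # dask not installed: always use the eager (NumPy) path
        return getattr(np, name)(*args, **kwargs)
    return wrapper


tensordot = dask_or_eager('tensordot', n_array_args=2)
result = tensordot(np.ones((3, 4)), np.ones((4, 2)), axes=1)
# result has shape (3, 2); each entry sums four 1*1 products, i.e. 4.0
```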
5 changes: 5 additions & 0 deletions xarray/test/test_dask.py
@@ -300,3 +300,8 @@ def test_stack(self):
dims=['w', 'z'])
assert stacked.data.chunks == expected.data.chunks
self.assertLazyAndIdentical(expected, stacked)

def test_dot(self):
Member:
This has the right idea, but it's a pretty complex example. I would prefer something more minimal like I suggested earlier, e.g.,

def test_dot(self):
    eager = self.eager_array.dot(self.eager_array[0])
    lazy = self.lazy_array.dot(self.lazy_array[0])
    self.assertLazyAndAllClose(eager, lazy) 

Member:
The reason to make this test something simpler is to reduce the redundancy with the other test_dot test case. Complex tests can be good to make sure things work, but it's often easier to debug simpler/shorter ones.

eager = self.eager_array.dot(self.eager_array[0])
lazy = self.lazy_array.dot(self.lazy_array[0])
self.assertLazyAndAllClose(eager, lazy)
38 changes: 38 additions & 0 deletions xarray/test/test_dataarray.py
@@ -1792,3 +1792,41 @@ def test_full_like(self):
actual = _full_like(DataArray([1, 2, 3]), fill_value=np.nan)
self.assertEqual(actual.dtype, np.float)
np.testing.assert_equal(actual.values, np.nan)

def test_dot(self):
x = np.linspace(-3, 3, 6)
y = np.linspace(-3, 3, 5)
z = range(4)
da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z'])

dm_vals = range(4)
dm = DataArray(dm_vals, coords=[z], dims=['z'])

# nd dot 1d
actual = da.dot(dm)
expected_vals = np.tensordot(da_vals, dm_vals, [2, 0])
expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y'])
self.assertDataArrayEqual(expected, actual)

# all shared dims
actual = da.dot(da)
expected_vals = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2]))
expected = DataArray(expected_vals)
self.assertDataArrayEqual(expected, actual)

Member:
only one space at the end of the file is necessary :)

# multiple shared dims
dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
j = np.linspace(-3, 3, 20)
dm = DataArray(dm_vals, coords=[j, y, z], dims=['j', 'y', 'z'])
actual = da.dot(dm)
expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2]))
expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j'])
self.assertDataArrayEqual(expected, actual)

with self.assertRaises(NotImplementedError):
da.dot(dm.to_dataset(name='dm'))
with self.assertRaises(TypeError):
da.dot(dm.values)
with self.assertRaisesRegexp(ValueError, 'no shared dimensions'):
da.dot(DataArray(1))
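The axis bookkeeping that the new method performs — finding the shared dimension names, mapping them to axis positions in each operand, and handing those to ``np.tensordot`` — can be sketched in plain NumPy terms (the helper name and signature here are illustrative, not xarray's API):

```python
import numpy as np


def dot_like(a_dims, a_vals, b_dims, b_vals):
    """Illustrative sketch of DataArray.dot's core: tensordot over shared dims."""
    shared = [d for d in a_dims if d in b_dims]
    # translate shared dimension names into axis positions for each operand
    axes = ([a_dims.index(d) for d in shared],
            [b_dims.index(d) for d in shared])
    # result keeps the non-shared dims of a, then the non-shared dims of b
    out_dims = ([d for d in a_dims if d not in shared] +
                [d for d in b_dims if d not in shared])
    return out_dims, np.tensordot(a_vals, b_vals, axes=axes)


dims, vals = dot_like(['x', 'y', 'z'], np.arange(24).reshape(2, 3, 4),
                      ['z'], np.arange(4))
# dims == ['x', 'y'], vals.shape == (2, 3)
```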