Add tensordot to DataArray class; also add its test to test_dataarray #731
@@ -1374,5 +1374,64 @@ def real(self):

```python
    def imag(self):
        return self._replace(self.variable.imag)

    def dot(self, other):
        """Perform dot product of two DataArrays along their shared dims.

        Equivalent to taking tensordot over all shared dims.

        Parameters
        ----------
        other : DataArray
            The other array with which the dot product is performed.

        Returns
        -------
        result : DataArray
            Array resulting from the dot product over all shared dimensions.

        See also
        --------
        np.tensordot(a, b, axes)

        Examples
        --------

        >>> da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
        >>> da = DataArray(da_vals, dims=['x', 'y', 'z'])
        >>> dm_vals = np.arange(4)
        >>> dm = DataArray(dm_vals, dims=['z'])

        >>> dm.dims
        ('z',)
        >>> da.dims
        ('x', 'y', 'z')

        >>> dot_result = da.dot(dm)
        >>> dot_result.dims
        ('x', 'y')
        """
        if isinstance(other, Dataset):
            raise NotImplementedError('dot products are not yet supported '
                                      'with Dataset objects.')
        if not isinstance(other, DataArray):
            raise TypeError('dot only operates on DataArrays.')

        # sum over the common dims
        dims = set(self.dims) & set(other.dims)
        if len(dims) == 0:
            raise ValueError('DataArrays have no shared dimensions over '
                             'which to perform dot.')

        self, other = align(self, other, join='inner', copy=False)

        axes = (self.get_axis_num(dims), other.get_axis_num(dims))
        new_data = ops.tensordot(self.data, other.data, axes=axes)

        new_coords = self.coords.merge(other.coords).drop(dims)
        new_dims = ([d for d in self.dims if d not in dims] +
                    [d for d in other.dims if d not in dims])

        return type(self)(new_data, new_coords, new_dims)
```
Author
Maybe not. Depends on whether we want to figure out a way to extend it to Dataset at some point.
Member
Sure, but you raise an exception if …
Author
I'm happy either way. @shoyer suggested this way; the way you're suggesting is the way I originally did it. I would side with you that it should be generic only when it needs to be generic.
Member
Well, this is what you get for listening to @shoyer. Either way is fine by me. I know what you were going for; it makes sense. This way is just a little harder to read and includes one extra function call.
Collaborator
Or if you inherit from DataArray, then the …
Member
In general, I prefer …
But in the scheme of things it isn't very important :).
```python
# priority must be higher than Variable to properly work with binary ufuncs
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60)
```
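To make the new method concrete, here is a minimal standalone sketch of the docstring example above (not part of the diff; it assumes an xarray build that includes this `dot` method, and the `xr` import alias is just convention):

```python
import numpy as np
import xarray as xr

# Same shapes as the docstring example above.
da = xr.DataArray(np.arange(6 * 5 * 4).reshape((6, 5, 4)),
                  dims=['x', 'y', 'z'])
dm = xr.DataArray(np.arange(4), dims=['z'])

result = da.dot(dm)  # contracts over the one shared dim, 'z'

# dot is equivalent to tensordot over the shared dims.
expected = np.tensordot(da.values, dm.values, axes=(2, 0))
assert result.dims == ('x', 'y')
np.testing.assert_array_equal(result.values, expected)
```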
@@ -300,3 +300,8 @@ def test_stack(self):

```python
                             dims=['w', 'z'])
        assert stacked.data.chunks == expected.data.chunks
        self.assertLazyAndIdentical(expected, stacked)

    def test_dot(self):
        eager = self.eager_array.dot(self.eager_array[0])
        lazy = self.lazy_array.dot(self.lazy_array[0])
        self.assertLazyAndAllClose(eager, lazy)
```

Member
This has the right idea, but it's a pretty complex example. I would prefer something more minimal, like I suggested earlier, e.g. …

Member
The reason to make this test something simpler is to reduce the redundancy with the other …
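The point of comparing an eager and a lazy result here is that `dot` should behave identically on numpy- and dask-backed data. A rough standalone sketch of the same check (assumptions: dask is installed and `DataArray.chunk` is available; this is not code from the PR):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(24.0).reshape((6, 4)), dims=['x', 'z'])
lazy = da.chunk({'x': 3})  # dask-backed copy of the same data

eager_result = da.dot(da[0])     # plain numpy code path
lazy_result = lazy.dot(lazy[0])  # dask code path, via ops.tensordot

# Both paths should produce the same values once computed.
np.testing.assert_allclose(eager_result.values, lazy_result.values)
```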
@@ -1792,3 +1792,41 @@ def test_full_like(self):

```python
        actual = _full_like(DataArray([1, 2, 3]), fill_value=np.nan)
        self.assertEqual(actual.dtype, np.float)
        np.testing.assert_equal(actual.values, np.nan)

    def test_dot(self):
        x = np.linspace(-3, 3, 6)
        y = np.linspace(-3, 3, 5)
        z = range(4)
        da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
        da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z'])

        dm_vals = range(4)
        dm = DataArray(dm_vals, coords=[z], dims=['z'])

        # nd dot 1d
        actual = da.dot(dm)
        expected_vals = np.tensordot(da_vals, dm_vals, [2, 0])
        expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y'])
        self.assertDataArrayEqual(expected, actual)

        # all shared dims
        actual = da.dot(da)
        expected_vals = np.tensordot(da_vals, da_vals,
                                     axes=([0, 1, 2], [0, 1, 2]))
        expected = DataArray(expected_vals)
        self.assertDataArrayEqual(expected, actual)

        # multiple shared dims
        dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
        j = np.linspace(-3, 3, 20)
        dm = DataArray(dm_vals, coords=[j, y, z], dims=['j', 'y', 'z'])
        actual = da.dot(dm)
        expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2]))
        expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j'])
        self.assertDataArrayEqual(expected, actual)

        with self.assertRaises(NotImplementedError):
            da.dot(dm.to_dataset(name='dm'))
        with self.assertRaises(TypeError):
            da.dot(dm.values)
        with self.assertRaisesRegexp(ValueError, 'no shared dimensions'):
            da.dot(DataArray(1))
```

Member
Only one space at the end of the file is necessary :)
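One detail the multiple-shared-dims case exercises: the implementation turns shared dimension names into positional axes with `get_axis_num` before handing off to tensordot. A hypothetical illustration of that bookkeeping (not from the PR; both calls iterate the same set, so the axis tuples stay paired even though set order is arbitrary):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((6, 5, 4)), dims=['x', 'y', 'z'])
dm = xr.DataArray(np.zeros((20, 5, 4)), dims=['j', 'y', 'z'])

dims = set(da.dims) & set(dm.dims)  # shared dims: {'y', 'z'}
# Map dim names to axis positions in each array, in the same set order.
axes = (da.get_axis_num(dims), dm.get_axis_num(dims))
print(axes)  # e.g. ((1, 2), (1, 2))
```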
I think it would be fine to remove lines 1413:1415, and change 1417 to something like: …
Yeah, could go either way here. I like raising NotImplementedError because it makes it more obvious to users that this behavior might change in the future. Besides, if we get lucky a user will notice the error and then implement the missing functionality themselves ;).
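To make the two error paths under discussion concrete, here is a small sketch matching the tests above (assumed standalone setup, not code from the thread; the error messages are the ones from the diff):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4), dims=['z'])
ds = da.to_dataset(name='dm')

try:
    da.dot(ds)  # Dataset input is explicitly unsupported for now
except NotImplementedError as err:
    print(err)  # "dot products are not yet supported with Dataset objects."

try:
    da.dot(da.values)  # a bare ndarray is rejected outright
except TypeError as err:
    print(err)  # "dot only operates on DataArrays."
```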