Skip to content

Commit 5433d9d

Browse files
authored
Fix issue with expand_dims on 0d arrays of tuples (#867)
* Fix issue with expand_dims on 0d arrays of tuples This was originally reported on the mailing list: https://groups.google.com/forum/#!topic/xarray/fz7HHgpgwk0 Unfortunately, the fix requires a backwards compatibility break, because it changes how 0d object arrays are handled: In [4]: xr.Variable([], object()).values Out[4]: array(<object object at 0x10072e2e0>, dtype=object) Previously, we just returned the original object. * Add another test * Fix issue with squeeze() on object arrays of lists * another test * another test
1 parent 6aedf98 commit 5433d9d

File tree

6 files changed

+137
-45
lines changed

6 files changed

+137
-45
lines changed

doc/whats-new.rst

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ Breaking changes
2626
~~~~~~~~~~~~~~~~
2727

2828
- Dropped support for Python 2.6 (:issue:`855`).
29-
- Indexing on multi-index now drop levels, which is consitent with pandas.
29+
- Indexing on a multi-index now drops levels, which is consistent with pandas.
3030
It also changes the name of the dimension / coordinate when the multi-index is
31-
reduced to a single index.
32-
- Contour plots no longer add a colorbar per default (:issue:`866`).
31+
reduced to a single index (:issue:`802`).
32+
- Contour plots no longer add a colorbar per default (:issue:`866`). Filled
33+
contour plots are unchanged.
34+
- ``DataArray.values`` and ``.data`` now always return a NumPy array-like
35+
object, even for 0-dimensional arrays with object dtype (:issue:`867`).
36+
Previously, ``.values`` returned native Python objects in such cases. To
37+
convert the values of scalar arrays to Python objects, use the ``.item()``
38+
method.
3339

3440
Enhancements
3541
~~~~~~~~~~~~
@@ -102,6 +108,9 @@ Bug fixes
102108
- ``Variable.copy(deep=True)`` no longer converts MultiIndex into a base Index
103109
(:issue:`769`). By `Benoit Bovy <https://github.com/benbovy>`_.
104110

111+
- Fixes for groupby on dimensions with a multi-index (:issue:`867`). By
112+
`Stephan Hoyer <https://github.com/shoyer>`_.
113+
105114
- Fix printing datasets with unicode attributes on Python 2 (:issue:`892`). By
106115
`Stephan Hoyer <https://github.com/shoyer>`_.
107116

xarray/core/indexing.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _asarray_tuplesafe(values):
125125
Adapted from pandas.core.common._asarray_tuplesafe
126126
"""
127127
if isinstance(values, tuple):
128-
result = utils.tuple_to_0darray(values)
128+
result = utils.to_0d_object_array(values)
129129
else:
130130
result = np.asarray(values)
131131
if result.ndim == 2:
@@ -396,9 +396,18 @@ def _convert_key(self, key):
396396
key = orthogonal_indexer(key, self.shape)
397397
return key
398398

399+
def _ensure_ndarray(self, value):
400+
# We always want the result of indexing to be a NumPy array. If it's
401+
# not, then it really should be a 0d array. Doing the coercion here
402+
# instead of inside variable.as_compatible_data makes it less error
403+
# prone.
404+
if not isinstance(value, np.ndarray):
405+
value = utils.to_0d_array(value)
406+
return value
407+
399408
def __getitem__(self, key):
400409
key = self._convert_key(key)
401-
return self.array[key]
410+
return self._ensure_ndarray(self.array[key])
402411

403412
def __setitem__(self, key, value):
404413
key = self._convert_key(key)
@@ -469,16 +478,22 @@ def __getitem__(self, key):
469478

470479
if isinstance(result, pd.Index):
471480
result = PandasIndexAdapter(result, dtype=self.dtype)
472-
elif result is pd.NaT:
473-
# work around the impossibility of casting NaT with asarray
474-
# note: it probably would be better in general to return
475-
# pd.Timestamp rather np.than datetime64 but this is easier
476-
# (for now)
477-
result = np.datetime64('NaT', 'ns')
478-
elif isinstance(result, timedelta):
479-
result = np.timedelta64(getattr(result, 'value', result), 'ns')
480-
elif self.dtype != object:
481-
result = np.asarray(result, dtype=self.dtype)
481+
else:
482+
# result is a scalar
483+
if result is pd.NaT:
484+
# work around the impossibility of casting NaT with asarray
485+
# note: it probably would be better in general to return
486+
# pd.Timestamp rather than np.datetime64 but this is easier
487+
# (for now)
488+
result = np.datetime64('NaT', 'ns')
489+
elif isinstance(result, timedelta):
490+
result = np.timedelta64(getattr(result, 'value', result), 'ns')
491+
elif self.dtype != object:
492+
result = np.asarray(result, dtype=self.dtype)
493+
494+
# as for numpy.ndarray indexing, we always want the result to be
495+
# a NumPy array.
496+
result = utils.to_0d_array(result)
482497

483498
return result
484499

xarray/core/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,23 @@ def is_valid_numpy_dtype(dtype):
178178
return True
179179

180180

181-
def tuple_to_0darray(value):
181+
def to_0d_object_array(value):
182+
"""Given a value, wrap it in a 0-D numpy.ndarray with dtype=object."""
182183
result = np.empty((1,), dtype=object)
183184
result[:] = [value]
184185
result.shape = ()
185186
return result
186187

187188

189+
def to_0d_array(value):
190+
"""Given a value, wrap it in a 0-D numpy.ndarray."""
191+
if np.isscalar(value) or (isinstance(value, np.ndarray)
192+
and value.ndim == 0):
193+
return np.array(value)
194+
else:
195+
return to_0d_object_array(value)
196+
197+
188198
def dict_equiv(first, second, compat=equivalent):
189199
"""Test equivalence of two dict-like objects. If any of the values are
190200
numpy arrays, compare them correctly.

xarray/core/variable.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def as_compatible_data(data, fastpath=False):
119119
return _maybe_wrap_data(data)
120120

121121
if isinstance(data, tuple):
122-
data = utils.tuple_to_0darray(data)
122+
data = utils.to_0d_object_array(data)
123123

124124
if isinstance(data, pd.Timestamp):
125125
# TODO: convert, handle datetime objects, too
@@ -159,19 +159,21 @@ def as_compatible_data(data, fastpath=False):
159159

160160
def _as_array_or_item(data):
161161
"""Return the given values as a numpy array, or as an individual item if
162-
it's a 0-dimensional object array or datetime64.
162+
it's a 0d datetime64 or timedelta64 array.
163163
164164
Importantly, this function does not copy data if it is already an ndarray -
165165
otherwise, it will not be possible to update Variable values in place.
166+
167+
This function mostly exists because 0-dimensional ndarrays with
168+
dtype=datetime64 are broken :(
169+
https://github.com/numpy/numpy/issues/4337
170+
https://github.com/numpy/numpy/issues/7619
171+
172+
TODO: remove this (replace with np.asarray) once these issues are fixed
166173
"""
167174
data = np.asarray(data)
168175
if data.ndim == 0:
169-
if data.dtype.kind == 'O':
170-
# unpack 0d object arrays to be consistent with numpy
171-
data = data.item()
172-
elif data.dtype.kind == 'M':
173-
# convert to a np.datetime64 object, because 0-dimensional ndarrays
174-
# with dtype=datetime64 are broken :(
176+
if data.dtype.kind == 'M':
175177
data = np.datetime64(data, 'ns')
176178
elif data.dtype.kind == 'm':
177179
data = np.timedelta64(data, 'ns')

xarray/test/test_groupby.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,46 @@
1+
import numpy as np
2+
import xarray as xr
13
from xarray.core.groupby import _consolidate_slices
24

35
import pytest
46

57

68
def test_consolidate_slices():
79

8-
assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)]
9-
assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)]
10-
assert (_consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)])
10+
assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)]
11+
assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)]
12+
assert (_consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)])
1113
== [slice(2, 6, 1)])
1214

13-
slices = [slice(2, 3), slice(5, 6)]
14-
assert _consolidate_slices(slices) == slices
15-
16-
with pytest.raises(ValueError):
17-
_consolidate_slices([slice(3), 4])
15+
slices = [slice(2, 3), slice(5, 6)]
16+
assert _consolidate_slices(slices) == slices
17+
18+
with pytest.raises(ValueError):
19+
_consolidate_slices([slice(3), 4])
20+
21+
22+
def test_multi_index_groupby_apply():
23+
# regression test for GH873
24+
ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(3, 4))},
25+
{'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]})
26+
doubled = 2 * ds
27+
group_doubled = (ds.stack(space=['x', 'y'])
28+
.groupby('space')
29+
.apply(lambda x: 2 * x)
30+
.unstack('space'))
31+
assert doubled.equals(group_doubled)
32+
33+
34+
def test_multi_index_groupby_sum():
35+
# regression test for GH873
36+
ds = xr.Dataset({'foo': (('x', 'y', 'z'), np.ones((3, 4, 2)))},
37+
{'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]})
38+
expected = ds.sum('z')
39+
actual = (ds.stack(space=['x', 'y'])
40+
.groupby('space')
41+
.sum('z')
42+
.unstack('space'))
43+
assert expected.equals(actual)
1844

1945

2046
# TODO: move other groupby tests from test_dataset and test_dataarray over here

xarray/test/test_variable.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def test_getitem_dict(self):
4747
expected = v[0]
4848
self.assertVariableIdentical(expected, actual)
4949

50-
def assertIndexedLikeNDArray(self, variable, expected_value0,
51-
expected_dtype=None):
50+
def _assertIndexedLikeNDArray(self, variable, expected_value0,
51+
expected_dtype=None):
5252
"""Given a 1-dimensional variable, verify that the variable is indexed
5353
like a numpy.ndarray.
5454
"""
@@ -66,52 +66,52 @@ def assertIndexedLikeNDArray(self, variable, expected_value0,
6666
# check output type instead of array dtype
6767
self.assertEqual(type(variable.values[0]), type(expected_value0))
6868
self.assertEqual(type(variable[0].values), type(expected_value0))
69-
else:
69+
elif expected_dtype is not False:
7070
self.assertEqual(variable.values[0].dtype, expected_dtype)
7171
self.assertEqual(variable[0].values.dtype, expected_dtype)
7272

7373
def test_index_0d_int(self):
7474
for value, dtype in [(0, np.int_),
7575
(np.int32(0), np.int32)]:
7676
x = self.cls(['x'], [value])
77-
self.assertIndexedLikeNDArray(x, value, dtype)
77+
self._assertIndexedLikeNDArray(x, value, dtype)
7878

7979
def test_index_0d_float(self):
8080
for value, dtype in [(0.5, np.float_),
8181
(np.float32(0.5), np.float32)]:
8282
x = self.cls(['x'], [value])
83-
self.assertIndexedLikeNDArray(x, value, dtype)
83+
self._assertIndexedLikeNDArray(x, value, dtype)
8484

8585
def test_index_0d_string(self):
8686
for value, dtype in [('foo', np.dtype('U3' if PY3 else 'S3')),
8787
(u'foo', np.dtype('U3'))]:
8888
x = self.cls(['x'], [value])
89-
self.assertIndexedLikeNDArray(x, value, dtype)
89+
self._assertIndexedLikeNDArray(x, value, dtype)
9090

9191
def test_index_0d_datetime(self):
9292
d = datetime(2000, 1, 1)
9393
x = self.cls(['x'], [d])
94-
self.assertIndexedLikeNDArray(x, np.datetime64(d))
94+
self._assertIndexedLikeNDArray(x, np.datetime64(d))
9595

9696
x = self.cls(['x'], [np.datetime64(d)])
97-
self.assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
97+
self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
9898

9999
x = self.cls(['x'], pd.DatetimeIndex([d]))
100-
self.assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
100+
self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
101101

102102
def test_index_0d_timedelta64(self):
103103
td = timedelta(hours=1)
104104

105105
x = self.cls(['x'], [np.timedelta64(td)])
106-
self.assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
106+
self._assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
107107

108108
x = self.cls(['x'], pd.to_timedelta([td]))
109-
self.assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
109+
self._assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
110110

111111
def test_index_0d_not_a_time(self):
112112
d = np.datetime64('NaT', 'ns')
113113
x = self.cls(['x'], [d])
114-
self.assertIndexedLikeNDArray(x, d, None)
114+
self._assertIndexedLikeNDArray(x, d)
115115

116116
def test_index_0d_object(self):
117117

@@ -130,7 +130,15 @@ def __repr__(self):
130130

131131
item = HashableItemWrapper((1, 2, 3))
132132
x = self.cls('x', [item])
133-
self.assertIndexedLikeNDArray(x, item)
133+
self._assertIndexedLikeNDArray(x, item, expected_dtype=False)
134+
135+
def test_0d_object_array_with_list(self):
136+
listarray = np.empty((1,), dtype=object)
137+
listarray[0] = [1, 2, 3]
138+
x = self.cls('x', listarray)
139+
assert x.data == listarray
140+
assert x[0].data == listarray.squeeze()
141+
assert x.squeeze().data == listarray.squeeze()
134142

135143
def test_index_and_concat_datetime(self):
136144
# regression test for #125
@@ -729,6 +737,19 @@ def test_transpose(self):
729737
w3 = Variable(['b', 'c', 'd', 'a'], np.einsum('abcd->bcda', x))
730738
self.assertVariableIdentical(w, w3.transpose('a', 'b', 'c', 'd'))
731739

740+
def test_transpose_0d(self):
741+
for value in [
742+
3.5,
743+
('a', 1),
744+
np.datetime64('2000-01-01'),
745+
np.timedelta64(1, 'h'),
746+
None,
747+
object(),
748+
]:
749+
variable = Variable([], value)
750+
actual = variable.transpose()
751+
assert actual.identical(variable)
752+
732753
def test_squeeze(self):
733754
v = Variable(['x', 'y'], [[1]])
734755
self.assertVariableIdentical(Variable([], 1), v.squeeze())
@@ -773,6 +794,15 @@ def test_expand_dims(self):
773794
with self.assertRaisesRegexp(ValueError, 'must be a superset'):
774795
v.expand_dims(['z'])
775796

797+
def test_expand_dims_object_dtype(self):
798+
v = Variable([], ('a', 1))
799+
actual = v.expand_dims(('x',), (3,))
800+
exp_values = np.empty((3,), dtype=object)
801+
for i in range(3):
802+
exp_values[i] = ('a', 1)
803+
expected = Variable(['x'], exp_values)
804+
assert actual.identical(expected)
805+
776806
def test_stack(self):
777807
v = Variable(['x', 'y'], [[0, 1], [2, 3]], {'foo': 'bar'})
778808
actual = v.stack(z=('x', 'y'))

0 commit comments

Comments
 (0)