Skip to content

Commit 220adbc

Browse files
fujiisoup authored and shoyer committed
sparse option to reindex and unstack (#3542)
* Added fill_value for unstack
* remove sparse option and fix unintended changes
* a bug fix
* Added sparse option to unstack and reindex
* black
* More tests
* black
* Remove sparse option from reindex
* try __array_function__ where
* flake8
1 parent dc559ea commit 220adbc

File tree

7 files changed

+113
-4
lines changed

7 files changed

+113
-4
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ Breaking changes
3737

3838
New Features
3939
~~~~~~~~~~~~
40+
- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`,
41+
:py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`,
42+
:py:meth:`~xarray.Dataset.reindex` (:issue:`3518`).
43+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
4044

4145
- Added the ``max_gap`` kwarg to :py:meth:`DataArray.interpolate_na` and
4246
:py:meth:`Dataset.interpolate_na`. This controls the maximum size of the data

xarray/core/alignment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def reindex_variables(
466466
tolerance: Any = None,
467467
copy: bool = True,
468468
fill_value: Optional[Any] = dtypes.NA,
469+
sparse: bool = False,
469470
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]:
470471
"""Conform a dictionary of aligned variables onto a new set of variables,
471472
filling in missing values with NaN.
@@ -503,6 +504,8 @@ def reindex_variables(
503504
the input. In either case, new xarray objects are always returned.
504505
fill_value : scalar, optional
505506
Value to use for newly missing values
507+
sparse: bool, optional
508+
Use a sparse array
506509
507510
Returns
508511
-------
@@ -571,6 +574,8 @@ def reindex_variables(
571574

572575
for name, var in variables.items():
573576
if name not in indexers:
577+
if sparse:
578+
var = var._as_sparse(fill_value=fill_value)
574579
key = tuple(
575580
slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None))
576581
for d in var.dims

xarray/core/dataarray.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1729,6 +1729,7 @@ def unstack(
17291729
self,
17301730
dim: Union[Hashable, Sequence[Hashable], None] = None,
17311731
fill_value: Any = dtypes.NA,
1732+
sparse: bool = False,
17321733
) -> "DataArray":
17331734
"""
17341735
Unstack existing dimensions corresponding to MultiIndexes into
@@ -1742,6 +1743,7 @@ def unstack(
17421743
Dimension(s) over which to unstack. By default unstacks all
17431744
MultiIndexes.
17441745
fill_value: value to be filled. By default, np.nan
1746+
sparse: use sparse-array if True
17451747
17461748
Returns
17471749
-------
@@ -1773,7 +1775,7 @@ def unstack(
17731775
--------
17741776
DataArray.stack
17751777
"""
1776-
ds = self._to_temp_dataset().unstack(dim, fill_value)
1778+
ds = self._to_temp_dataset().unstack(dim, fill_value, sparse)
17771779
return self._from_temp_dataset(ds)
17781780

17791781
def to_unstacked_dataset(self, dim, level=0):

xarray/core/dataset.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,7 @@ def reindex(
22862286
the input. In either case, a new xarray object is always returned.
22872287
fill_value : scalar, optional
22882288
Value to use for newly missing values
2289+
sparse: use sparse-array. By default, False
22892290
**indexers_kwarg : {dim: indexer, ...}, optional
22902291
Keyword arguments in the same form as ``indexers``.
22912292
One of indexers or indexers_kwargs must be provided.
@@ -2428,6 +2429,29 @@ def reindex(
24282429
the original and desired indexes. If you do want to fill in the `NaN` values present in the
24292430
original dataset, use the :py:meth:`~Dataset.fillna()` method.
24302431
2432+
"""
2433+
return self._reindex(
2434+
indexers,
2435+
method,
2436+
tolerance,
2437+
copy,
2438+
fill_value,
2439+
sparse=False,
2440+
**indexers_kwargs,
2441+
)
2442+
2443+
def _reindex(
2444+
self,
2445+
indexers: Mapping[Hashable, Any] = None,
2446+
method: str = None,
2447+
tolerance: Number = None,
2448+
copy: bool = True,
2449+
fill_value: Any = dtypes.NA,
2450+
sparse: bool = False,
2451+
**indexers_kwargs: Any,
2452+
) -> "Dataset":
2453+
"""
2454+
same as reindex but supports the sparse option
24312455
"""
24322456
indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
24332457

@@ -2444,6 +2468,7 @@ def reindex(
24442468
tolerance,
24452469
copy=copy,
24462470
fill_value=fill_value,
2471+
sparse=sparse,
24472472
)
24482473
coord_names = set(self._coord_names)
24492474
coord_names.update(indexers)
@@ -3327,7 +3352,7 @@ def ensure_stackable(val):
33273352

33283353
return data_array
33293354

3330-
def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
3355+
def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
33313356
index = self.get_index(dim)
33323357
index = index.remove_unused_levels()
33333358
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
@@ -3336,7 +3361,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
33363361
if index.equals(full_idx):
33373362
obj = self
33383363
else:
3339-
obj = self.reindex({dim: full_idx}, copy=False, fill_value=fill_value)
3364+
obj = self._reindex(
3365+
{dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse
3366+
)
33403367

33413368
new_dim_names = index.names
33423369
new_dim_sizes = [lev.size for lev in index.levels]
@@ -3366,6 +3393,7 @@ def unstack(
33663393
self,
33673394
dim: Union[Hashable, Iterable[Hashable]] = None,
33683395
fill_value: Any = dtypes.NA,
3396+
sparse: bool = False,
33693397
) -> "Dataset":
33703398
"""
33713399
Unstack existing dimensions corresponding to MultiIndexes into
@@ -3379,6 +3407,7 @@ def unstack(
33793407
Dimension(s) over which to unstack. By default unstacks all
33803408
MultiIndexes.
33813409
fill_value: value to be filled. By default, np.nan
3410+
sparse: use sparse-array if True
33823411
33833412
Returns
33843413
-------
@@ -3416,7 +3445,7 @@ def unstack(
34163445

34173446
result = self.copy(deep=False)
34183447
for dim in dims:
3419-
result = result._unstack_once(dim, fill_value)
3448+
result = result._unstack_once(dim, fill_value, sparse)
34203449
return result
34213450

34223451
def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset":

xarray/core/variable.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,36 @@ def chunk(self, chunks=None, name=None, lock=False):
993993

994994
return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)
995995

996+
def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
997+
"""
998+
use sparse-array as backend.
999+
"""
1000+
import sparse
1001+
1002+
# TODO what to do if dask-backed?
1003+
if fill_value is dtypes.NA:
1004+
dtype, fill_value = dtypes.maybe_promote(self.dtype)
1005+
else:
1006+
dtype = dtypes.result_type(self.dtype, fill_value)
1007+
1008+
if sparse_format is _default:
1009+
sparse_format = "coo"
1010+
try:
1011+
as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower()))
1012+
except AttributeError:
1013+
raise ValueError("{} is not a valid sparse format".format(sparse_format))
1014+
1015+
data = as_sparse(self.data.astype(dtype), fill_value=fill_value)
1016+
return self._replace(data=data)
1017+
1018+
def _to_dense(self):
1019+
"""
1020+
Change backend from sparse to np.array
1021+
"""
1022+
if hasattr(self._data, "todense"):
1023+
return self._replace(data=self._data.todense())
1024+
return self.copy(deep=False)
1025+
9961026
def isel(
9971027
self: VariableType,
9981028
indexers: Mapping[Hashable, Any] = None,
@@ -2021,6 +2051,14 @@ def chunk(self, chunks=None, name=None, lock=False):
20212051
# Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
20222052
return self.copy(deep=False)
20232053

2054+
def _as_sparse(self, sparse_format=_default, fill_value=_default):
2055+
# Dummy
2056+
return self.copy(deep=False)
2057+
2058+
def _to_dense(self):
2059+
# Dummy
2060+
return self.copy(deep=False)
2061+
20242062
def _finalize_indexing_result(self, dims, data):
20252063
if getattr(data, "ndim", 0) != 1:
20262064
# returns Variable rather than IndexVariable if multi-dimensional

xarray/tests/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2811,6 +2811,25 @@ def test_unstack_fill_value(self):
28112811
expected = ds["var"].unstack("index").fillna(-1).astype(np.int)
28122812
assert actual.equals(expected)
28132813

2814+
@requires_sparse
2815+
def test_unstack_sparse(self):
2816+
ds = xr.Dataset(
2817+
{"var": (("x",), np.arange(6))},
2818+
coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)},
2819+
)
2820+
# make ds incomplete
2821+
ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"])
2822+
# test sparse option
2823+
actual = ds.unstack("index", sparse=True)
2824+
expected = ds.unstack("index")
2825+
assert actual["var"].variable._to_dense().equals(expected["var"].variable)
2826+
assert actual["var"].data.density < 1.0
2827+
2828+
actual = ds["var"].unstack("index", sparse=True)
2829+
expected = ds["var"].unstack("index")
2830+
assert actual.variable._to_dense().equals(expected.variable)
2831+
assert actual.data.density < 1.0
2832+
28142833
def test_stack_unstack_fast(self):
28152834
ds = Dataset(
28162835
{

xarray/tests/test_variable.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
assert_identical,
3434
raises_regex,
3535
requires_dask,
36+
requires_sparse,
3637
source_ndarray,
3738
)
3839

@@ -1862,6 +1863,17 @@ def test_getitem_with_mask_nd_indexer(self):
18621863
)
18631864

18641865

1866+
@requires_sparse
1867+
class TestVariableWithSparse:
1868+
# TODO inherit VariableSubclassobjects to cover more tests
1869+
1870+
def test_as_sparse(self):
1871+
data = np.arange(12).reshape(3, 4)
1872+
var = Variable(("x", "y"), data)._as_sparse(fill_value=-1)
1873+
actual = var._to_dense()
1874+
assert_identical(var, actual)
1875+
1876+
18651877
class TestIndexVariable(VariableSubclassobjects):
18661878
cls = staticmethod(IndexVariable)
18671879

0 commit comments

Comments
 (0)