Skip to content

Commit 7fb3b19

Browse files
griverat authored and shoyer committed
Accept int value in head, thin and tail (#3298)
* Accept int value in head, thin and tail
* Fix typing
* Remove thin default value and add suggestions
* Fix typing and change raise message
1 parent e90e8bc commit 7fb3b19

File tree

4 files changed: 171 additions and 31 deletions

xarray/core/dataarray.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,52 +1035,54 @@ def sel(
10351035
return self._from_temp_dataset(ds)
10361036

10371037
def head(
1038-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
1038+
self,
1039+
indexers: Union[Mapping[Hashable, int], int] = None,
1040+
**indexers_kwargs: Any
10391041
) -> "DataArray":
10401042
"""Return a new DataArray whose data is given by the first `n`
1041-
values along the specified dimension(s).
1043+
values along the specified dimension(s). Default `n` = 5
10421044
10431045
See Also
10441046
--------
10451047
Dataset.head
10461048
DataArray.tail
10471049
DataArray.thin
10481050
"""
1049-
1050-
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "head")
1051-
ds = self._to_temp_dataset().head(indexers=indexers)
1051+
ds = self._to_temp_dataset().head(indexers, **indexers_kwargs)
10521052
return self._from_temp_dataset(ds)
10531053

10541054
def tail(
1055-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
1055+
self,
1056+
indexers: Union[Mapping[Hashable, int], int] = None,
1057+
**indexers_kwargs: Any
10561058
) -> "DataArray":
10571059
"""Return a new DataArray whose data is given by the last `n`
1058-
values along the specified dimension(s).
1060+
values along the specified dimension(s). Default `n` = 5
10591061
10601062
See Also
10611063
--------
10621064
Dataset.tail
10631065
DataArray.head
10641066
DataArray.thin
10651067
"""
1066-
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "tail")
1067-
ds = self._to_temp_dataset().tail(indexers=indexers)
1068+
ds = self._to_temp_dataset().tail(indexers, **indexers_kwargs)
10681069
return self._from_temp_dataset(ds)
10691070

10701071
def thin(
1071-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
1072+
self,
1073+
indexers: Union[Mapping[Hashable, int], int] = None,
1074+
**indexers_kwargs: Any
10721075
) -> "DataArray":
10731076
"""Return a new DataArray whose data is given by each `n` value
1074-
along the specified dimension(s).
1077+
along the specified dimension(s). `n` has no default and must be given.
10751078
10761079
See Also
10771080
--------
10781081
Dataset.thin
10791082
DataArray.head
10801083
DataArray.tail
10811084
"""
1082-
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "thin")
1083-
ds = self._to_temp_dataset().thin(indexers=indexers)
1085+
ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs)
10841086
return self._from_temp_dataset(ds)
10851087

10861088
def broadcast_like(

xarray/core/dataset.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,15 +2009,18 @@ def sel(
20092009
return result._overwrite_indexes(new_indexes)
20102010

20112011
def head(
2012-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
2012+
self,
2013+
indexers: Union[Mapping[Hashable, int], int] = None,
2014+
**indexers_kwargs: Any
20132015
) -> "Dataset":
20142016
"""Returns a new dataset with the first `n` values of each array
20152017
for the specified dimension(s).
20162018
20172019
Parameters
20182020
----------
2019-
indexers : dict, optional
2020-
A dict with keys matching dimensions and integer values `n`.
2021+
indexers : dict or int, default: 5
2022+
A dict with keys matching dimensions and integer values `n`
2023+
or a single integer `n` applied over all dimensions.
20212024
If neither indexers nor indexers_kwargs is provided, `n` defaults to 5 for every dimension.
20222025
**indexers_kwargs : {dim: n, ...}, optional
20232026
The keyword arguments form of ``indexers``.
@@ -2030,20 +2033,41 @@ def head(
20302033
Dataset.thin
20312034
DataArray.head
20322035
"""
2036+
if not indexers_kwargs:
2037+
if indexers is None:
2038+
indexers = 5
2039+
if not isinstance(indexers, int) and not is_dict_like(indexers):
2040+
raise TypeError("indexers must be either dict-like or a single integer")
2041+
if isinstance(indexers, int):
2042+
indexers = {dim: indexers for dim in self.dims}
20332043
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "head")
2034-
indexers = {k: slice(val) for k, val in indexers.items()}
2035-
return self.isel(indexers)
2044+
for k, v in indexers.items():
2045+
if not isinstance(v, int):
2046+
raise TypeError(
2047+
"expected integer type indexer for "
2048+
"dimension %r, found %r" % (k, type(v))
2049+
)
2050+
elif v < 0:
2051+
raise ValueError(
2052+
"expected positive integer as indexer "
2053+
"for dimension %r, found %s" % (k, v)
2054+
)
2055+
indexers_slices = {k: slice(val) for k, val in indexers.items()}
2056+
return self.isel(indexers_slices)
20362057

20372058
def tail(
2038-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
2059+
self,
2060+
indexers: Union[Mapping[Hashable, int], int] = None,
2061+
**indexers_kwargs: Any
20392062
) -> "Dataset":
20402063
"""Returns a new dataset with the last `n` values of each array
20412064
for the specified dimension(s).
20422065
20432066
Parameters
20442067
----------
2045-
indexers : dict, optional
2046-
A dict with keys matching dimensions and integer values `n`.
2068+
indexers : dict or int, default: 5
2069+
A dict with keys matching dimensions and integer values `n`
2070+
or a single integer `n` applied over all dimensions.
20472071
If neither indexers nor indexers_kwargs is provided, `n` defaults to 5 for every dimension.
20482072
**indexers_kwargs : {dim: n, ...}, optional
20492073
The keyword arguments form of ``indexers``.
@@ -2056,24 +2080,44 @@ def tail(
20562080
Dataset.thin
20572081
DataArray.tail
20582082
"""
2059-
2083+
if not indexers_kwargs:
2084+
if indexers is None:
2085+
indexers = 5
2086+
if not isinstance(indexers, int) and not is_dict_like(indexers):
2087+
raise TypeError("indexers must be either dict-like or a single integer")
2088+
if isinstance(indexers, int):
2089+
indexers = {dim: indexers for dim in self.dims}
20602090
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "tail")
2061-
indexers = {
2091+
for k, v in indexers.items():
2092+
if not isinstance(v, int):
2093+
raise TypeError(
2094+
"expected integer type indexer for "
2095+
"dimension %r, found %r" % (k, type(v))
2096+
)
2097+
elif v < 0:
2098+
raise ValueError(
2099+
"expected positive integer as indexer "
2100+
"for dimension %r, found %s" % (k, v)
2101+
)
2102+
indexers_slices = {
20622103
k: slice(-val, None) if val != 0 else slice(val)
20632104
for k, val in indexers.items()
20642105
}
2065-
return self.isel(indexers)
2106+
return self.isel(indexers_slices)
20662107

20672108
def thin(
2068-
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
2109+
self,
2110+
indexers: Union[Mapping[Hashable, int], int] = None,
2111+
**indexers_kwargs: Any
20692112
) -> "Dataset":
20702113
"""Returns a new dataset with each array indexed along every `n`th
20712114
value for the specified dimension(s)
20722115
20732116
Parameters
20742117
----------
2075-
indexers : dict, optional
2076-
A dict with keys matching dimensions and integer values `n`.
2118+
indexers : dict or int
2119+
A dict with keys matching dimensions and integer values `n`
2120+
or a single integer `n` applied over all dimensions.
20772121
One of indexers or indexers_kwargs must be provided.
20782122
**indexers_kwargs : {dim: n, ...}, optional
20792123
The keyword arguments form of ``indexers``.
@@ -2086,11 +2130,30 @@ def thin(
20862130
Dataset.tail
20872131
DataArray.thin
20882132
"""
2133+
if (
2134+
not indexers_kwargs
2135+
and not isinstance(indexers, int)
2136+
and not is_dict_like(indexers)
2137+
):
2138+
raise TypeError("indexers must be either dict-like or a single integer")
2139+
if isinstance(indexers, int):
2140+
indexers = {dim: indexers for dim in self.dims}
20892141
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "thin")
2090-
if 0 in indexers.values():
2091-
raise ValueError("step cannot be zero")
2092-
indexers = {k: slice(None, None, val) for k, val in indexers.items()}
2093-
return self.isel(indexers)
2142+
for k, v in indexers.items():
2143+
if not isinstance(v, int):
2144+
raise TypeError(
2145+
"expected integer type indexer for "
2146+
"dimension %r, found %r" % (k, type(v))
2147+
)
2148+
elif v < 0:
2149+
raise ValueError(
2150+
"expected positive integer as indexer "
2151+
"for dimension %r, found %s" % (k, v)
2152+
)
2153+
elif v == 0:
2154+
raise ValueError("step cannot be zero")
2155+
indexers_slices = {k: slice(None, None, val) for k, val in indexers.items()}
2156+
return self.isel(indexers_slices)
20942157

20952158
def broadcast_like(
20962159
self, other: Union["Dataset", "DataArray"], exclude: Iterable[Hashable] = None

xarray/tests/test_dataarray.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,48 @@ def test_isel_drop(self):
10051005
def test_head(self):
10061006
assert_equal(self.dv.isel(x=slice(5)), self.dv.head(x=5))
10071007
assert_equal(self.dv.isel(x=slice(0)), self.dv.head(x=0))
1008+
assert_equal(
1009+
self.dv.isel({dim: slice(6) for dim in self.dv.dims}), self.dv.head(6)
1010+
)
1011+
assert_equal(
1012+
self.dv.isel({dim: slice(5) for dim in self.dv.dims}), self.dv.head()
1013+
)
1014+
with raises_regex(TypeError, "either dict-like or a single int"):
1015+
self.dv.head([3])
1016+
with raises_regex(TypeError, "expected integer type"):
1017+
self.dv.head(x=3.1)
1018+
with raises_regex(ValueError, "expected positive int"):
1019+
self.dv.head(-3)
10081020

10091021
def test_tail(self):
10101022
assert_equal(self.dv.isel(x=slice(-5, None)), self.dv.tail(x=5))
10111023
assert_equal(self.dv.isel(x=slice(0)), self.dv.tail(x=0))
1024+
assert_equal(
1025+
self.dv.isel({dim: slice(-6, None) for dim in self.dv.dims}),
1026+
self.dv.tail(6),
1027+
)
1028+
assert_equal(
1029+
self.dv.isel({dim: slice(-5, None) for dim in self.dv.dims}), self.dv.tail()
1030+
)
1031+
with raises_regex(TypeError, "either dict-like or a single int"):
1032+
self.dv.tail([3])
1033+
with raises_regex(TypeError, "expected integer type"):
1034+
self.dv.tail(x=3.1)
1035+
with raises_regex(ValueError, "expected positive int"):
1036+
self.dv.tail(-3)
10121037

10131038
def test_thin(self):
10141039
assert_equal(self.dv.isel(x=slice(None, None, 5)), self.dv.thin(x=5))
1040+
assert_equal(
1041+
self.dv.isel({dim: slice(None, None, 6) for dim in self.dv.dims}),
1042+
self.dv.thin(6),
1043+
)
1044+
with raises_regex(TypeError, "either dict-like or a single int"):
1045+
self.dv.thin([3])
1046+
with raises_regex(TypeError, "expected integer type"):
1047+
self.dv.thin(x=3.1)
1048+
with raises_regex(ValueError, "expected positive int"):
1049+
self.dv.thin(-3)
10151050
with raises_regex(ValueError, "cannot be zero"):
10161051
self.dv.thin(time=0)
10171052

xarray/tests/test_dataset.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,21 @@ def test_head(self):
14221422
actual = data.head(time=0)
14231423
assert_equal(expected, actual)
14241424

1425+
expected = data.isel({dim: slice(6) for dim in data.dims})
1426+
actual = data.head(6)
1427+
assert_equal(expected, actual)
1428+
1429+
expected = data.isel({dim: slice(5) for dim in data.dims})
1430+
actual = data.head()
1431+
assert_equal(expected, actual)
1432+
1433+
with raises_regex(TypeError, "either dict-like or a single int"):
1434+
data.head([3])
1435+
with raises_regex(TypeError, "expected integer type"):
1436+
data.head(dim2=3.1)
1437+
with raises_regex(ValueError, "expected positive int"):
1438+
data.head(time=-3)
1439+
14251440
def test_tail(self):
14261441
data = create_test_data()
14271442

@@ -1433,15 +1448,40 @@ def test_tail(self):
14331448
actual = data.tail(dim1=0)
14341449
assert_equal(expected, actual)
14351450

1451+
expected = data.isel({dim: slice(-6, None) for dim in data.dims})
1452+
actual = data.tail(6)
1453+
assert_equal(expected, actual)
1454+
1455+
expected = data.isel({dim: slice(-5, None) for dim in data.dims})
1456+
actual = data.tail()
1457+
assert_equal(expected, actual)
1458+
1459+
with raises_regex(TypeError, "either dict-like or a single int"):
1460+
data.tail([3])
1461+
with raises_regex(TypeError, "expected integer type"):
1462+
data.tail(dim2=3.1)
1463+
with raises_regex(ValueError, "expected positive int"):
1464+
data.tail(time=-3)
1465+
14361466
def test_thin(self):
14371467
data = create_test_data()
14381468

14391469
expected = data.isel(time=slice(None, None, 5), dim2=slice(None, None, 6))
14401470
actual = data.thin(time=5, dim2=6)
14411471
assert_equal(expected, actual)
14421472

1473+
expected = data.isel({dim: slice(None, None, 6) for dim in data.dims})
1474+
actual = data.thin(6)
1475+
assert_equal(expected, actual)
1476+
1477+
with raises_regex(TypeError, "either dict-like or a single int"):
1478+
data.thin([3])
1479+
with raises_regex(TypeError, "expected integer type"):
1480+
data.thin(dim2=3.1)
14431481
with raises_regex(ValueError, "cannot be zero"):
14441482
data.thin(time=0)
1483+
with raises_regex(ValueError, "expected positive int"):
1484+
data.thin(time=-3)
14451485

14461486
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
14471487
def test_sel_fancy(self):

0 commit comments

Comments
 (0)