Skip to content

Commit a69350d

Browse files
authored
move xarray functions to xrcompat (#35)
* move xarray functions to xrcompat * fix tests * Update mplotutils/xrcompat.py * Update mplotutils/xrcompat.py * Update mplotutils/xrcompat.py
1 parent c6ed25a commit a69350d

File tree

4 files changed

+91
-95
lines changed

4 files changed

+91
-95
lines changed

mplotutils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .cartopy_utils import *
55
from .colorbar_utils import *
66
from .mpl_utils import *
7+
from .xrcompat import *
78

89
try:
910
from importlib.metadata import version as _get_version

mplotutils/cartopy_utils.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import warnings
2-
31
import cartopy.crs as ccrs
42
import matplotlib.pyplot as plt
53
import numpy as np
@@ -32,98 +30,6 @@ def sample_data_map(nlons, nlats):
3230
return lon, lat, data
3331

3432

35-
# =============================================================================
36-
37-
38-
# from xarray
39-
def infer_interval_breaks(x, y, clip=False):
40-
"""find edges of gridcells, given their centers"""
41-
42-
# TODO: require cartopy >= 0.21 when removing this function
43-
warnings.warn(
44-
"It's no longer necessary to compute the edges of the array. This is now done"
45-
"in matplotlib. This function will be removed from mplotutils in a future "
46-
"version.",
47-
FutureWarning,
48-
)
49-
50-
if len(x.shape) == 1:
51-
x = _infer_interval_breaks(x)
52-
y = _infer_interval_breaks(y)
53-
else:
54-
# we have to infer the intervals on both axes
55-
x = _infer_interval_breaks(x, axis=1)
56-
x = _infer_interval_breaks(x, axis=0)
57-
y = _infer_interval_breaks(y, axis=1)
58-
y = _infer_interval_breaks(y, axis=0)
59-
60-
if clip:
61-
y = np.clip(y, -90, 90)
62-
63-
return x, y
64-
65-
66-
# from xarray
67-
def _infer_interval_breaks(coord, axis=0):
68-
"""
69-
>>> _infer_interval_breaks(np.arange(5))
70-
array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
71-
>>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1)
72-
array([[-0.5, 0.5, 1.5],
73-
[ 2.5, 3.5, 4.5]])
74-
"""
75-
76-
if not _is_monotonic(coord, axis=axis):
77-
raise ValueError(
78-
"The input coordinate is not sorted in increasing "
79-
"order along axis %d. This can lead to unexpected "
80-
"results. Consider calling the `sortby` method on "
81-
"the input DataArray. To plot data with categorical "
82-
"axes, consider using the `heatmap` function from "
83-
"the `seaborn` statistical plotting library." % axis
84-
)
85-
86-
coord = np.asarray(coord)
87-
deltas = 0.5 * np.diff(coord, axis=axis)
88-
if deltas.size == 0:
89-
deltas = np.array(0.0)
90-
first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
91-
last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
92-
trim_last = tuple(
93-
slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim)
94-
)
95-
return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis)
96-
97-
98-
# from xarray
99-
def _is_monotonic(coord, axis=0):
100-
"""
101-
>>> _is_monotonic(np.array([0, 1, 2]))
102-
True
103-
>>> _is_monotonic(np.array([2, 1, 0]))
104-
True
105-
>>> _is_monotonic(np.array([0, 2, 1]))
106-
False
107-
"""
108-
coord = np.asarray(coord)
109-
110-
if coord.shape[axis] < 3:
111-
return True
112-
else:
113-
n = coord.shape[axis]
114-
delta_pos = coord.take(np.arange(1, n), axis=axis) >= coord.take(
115-
np.arange(0, n - 1), axis=axis
116-
)
117-
delta_neg = coord.take(np.arange(1, n), axis=axis) <= coord.take(
118-
np.arange(0, n - 1), axis=axis
119-
)
120-
121-
return np.all(delta_pos) or np.all(delta_neg)
122-
123-
124-
# =============================================================================
125-
126-
12733
def cyclic_dataarray(obj, coord="lon"):
12834
"""Add a cyclic coordinate point to a DataArray or Dataset along a dimension.
12935

mplotutils/tests/test_infer_interval_breaks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
from numpy.testing import assert_array_equal # noqa: F401
44

5-
from mplotutils.cartopy_utils import _infer_interval_breaks, infer_interval_breaks
5+
from mplotutils.xrcompat import _infer_interval_breaks, infer_interval_breaks
66

77

88
def test_infer_interval_breaks_warns():

mplotutils/xrcompat.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# code vendored from xarray under the conditions of their license
2+
# see licenses/XARRAY_LICENSE
3+
4+
import warnings
5+
6+
import numpy as np
7+
8+
9+
def infer_interval_breaks(x, y, clip=False):
10+
"""find edges of gridcells, given their centers"""
11+
12+
# TODO: require cartopy >= 0.21 when removing this function
13+
warnings.warn(
14+
"It's no longer necessary to compute the edges of the array. This is now done"
15+
"in matplotlib. This function will be removed from mplotutils in a future "
16+
"version.",
17+
FutureWarning,
18+
)
19+
20+
if len(x.shape) == 1:
21+
x = _infer_interval_breaks(x)
22+
y = _infer_interval_breaks(y)
23+
else:
24+
# we have to infer the intervals on both axes
25+
x = _infer_interval_breaks(x, axis=1)
26+
x = _infer_interval_breaks(x, axis=0)
27+
y = _infer_interval_breaks(y, axis=1)
28+
y = _infer_interval_breaks(y, axis=0)
29+
30+
if clip:
31+
y = np.clip(y, -90, 90)
32+
33+
return x, y
34+
35+
36+
def _infer_interval_breaks(coord, axis=0):
37+
"""
38+
>>> _infer_interval_breaks(np.arange(5))
39+
array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
40+
>>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1)
41+
array([[-0.5, 0.5, 1.5],
42+
[ 2.5, 3.5, 4.5]])
43+
"""
44+
45+
if not _is_monotonic(coord, axis=axis):
46+
raise ValueError(
47+
"The input coordinate is not sorted in increasing "
48+
"order along axis %d. This can lead to unexpected "
49+
"results. Consider calling the `sortby` method on "
50+
"the input DataArray. To plot data with categorical "
51+
"axes, consider using the `heatmap` function from "
52+
"the `seaborn` statistical plotting library." % axis
53+
)
54+
55+
coord = np.asarray(coord)
56+
deltas = 0.5 * np.diff(coord, axis=axis)
57+
if deltas.size == 0:
58+
deltas = np.array(0.0)
59+
first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
60+
last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
61+
trim_last = tuple(
62+
slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim)
63+
)
64+
return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis)
65+
66+
67+
def _is_monotonic(coord, axis=0):
68+
"""
69+
>>> _is_monotonic(np.array([0, 1, 2]))
70+
True
71+
>>> _is_monotonic(np.array([2, 1, 0]))
72+
True
73+
>>> _is_monotonic(np.array([0, 2, 1]))
74+
False
75+
"""
76+
coord = np.asarray(coord)
77+
78+
if coord.shape[axis] < 3:
79+
return True
80+
else:
81+
n = coord.shape[axis]
82+
delta_pos = coord.take(np.arange(1, n), axis=axis) >= coord.take(
83+
np.arange(0, n - 1), axis=axis
84+
)
85+
delta_neg = coord.take(np.arange(1, n), axis=axis) <= coord.take(
86+
np.arange(0, n - 1), axis=axis
87+
)
88+
89+
return np.all(delta_pos) or np.all(delta_neg)

0 commit comments

Comments
 (0)