Skip to content

Commit f513bc3

Browse files
authored
refactor cyclic_dataarray (#33)
* refactor cyclic_dataarray * changelog
1 parent a7a675d commit f513bc3

File tree

4 files changed

+42
-66
lines changed

4 files changed

+42
-66
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
gridspec.
2020
* Replaced `ax.get_subplotspec().get_geometry()` with `ax.get_subplotspec().get_geometry()`
2121
as the former was deprecated in matplotlib (#8).
22+
* Refactor `mpu.cyclic_dataarray` using `obj.pad` ([#33](https://github.com/mathause/mplotutils/pull/33)).
2223
* Enabled CI on github actions (#9).
2324
* Formatted with black and isort, checked with flake8.
2425

licenses/PLOT_ALL_IN_NCFILE_LICENSE

Lines changed: 0 additions & 11 deletions
This file was deleted.

mplotutils/cartopy_utils.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
# code by M.Hauser
2-
31
import warnings
42

53
import cartopy.crs as ccrs
6-
import cartopy.util as cutil
74
import matplotlib.pyplot as plt
85
import numpy as np
96
import shapely.geometry as sgeom
@@ -127,14 +124,29 @@ def _is_monotonic(coord, axis=0):
127124
# =============================================================================
128125

129126

130-
def cyclic_dataarray(da, coord="lon"):
131-
"""
132-
Add a cyclic coordinate point to a DataArray along a specified named dimension.
127+
def cyclic_dataarray(obj, coord="lon"):
128+
"""Add a cyclic coordinate point to a DataArray or Dataset along a dimension.
129+
130+
Parameters
131+
----------
132+
obj : xr.Dataset | xr.DataArray
133+
Object to add the cyclic data point to.
134+
coord : str, default: "lon"
135+
Name of the
133136
137+
Returns
138+
-------
139+
obj_cyclic : xr.Dataset | xr.DataArray
140+
The same as `obj` with a cyclic data point added.
141+
142+
Examples
143+
--------
134144
>>> import xarray as xr
135-
>>> data = xr.DataArray([[1, 2, 3], [4, 5, 6]],
136-
... coords={'x': [1, 2], 'y': range(3)},
137-
... dims=['x', 'y'])
145+
>>> data = xr.DataArray(
146+
... [[1, 2, 3], [4, 5, 6]],
147+
... coords={'x': [1, 2], 'y': range(3)},
148+
... dims=['x', 'y']
149+
... )
138150
>>> data_cyclic = cyclic_dataarray(data, 'y')
139151
>>> data_cyclic
140152
<xarray.DataArray (x: 2, y: 4)>
@@ -144,38 +156,9 @@ def cyclic_dataarray(da, coord="lon"):
144156
* x (x) int64 1 2
145157
* y (y) int64 0 1 2 3
146158
147-
Notes
148-
-----
149-
After: https://github.com/darothen/plot-all-in-ncfile/blob/master/plot_util.py
150-
151159
"""
152-
import xarray as xr
153160

154-
assert isinstance(da, xr.DataArray)
155-
156-
lon_idx = da.dims.index(coord)
157-
cyclic_data, cyclic_coord = cutil.add_cyclic_point(
158-
da.values, coord=da.coords[coord], axis=lon_idx
159-
)
160-
161-
# Copy and add the cyclic coordinate and data
162-
new_coords = dict(da.coords)
163-
new_coords[coord] = cyclic_coord
164-
new_values = cyclic_data
165-
166-
new_da = xr.DataArray(new_values, dims=da.dims, coords=new_coords)
167-
168-
# Copy the attributes for the re-constructed data and coords
169-
for att, val in da.attrs.items():
170-
new_da.attrs[att] = val
171-
for c in da.coords:
172-
for att in da.coords[c].attrs:
173-
new_da.coords[c].attrs[att] = da.coords[c].attrs[att]
174-
175-
return new_da
176-
177-
178-
# =============================================================================
161+
return obj.pad({coord: (0, 1)}, mode="wrap")
179162

180163

181164
def ylabel_map(s, labelpad=None, size=None, weight=None, y=0.5, ax=None, **kwargs):
Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,31 @@
1-
import numpy as np
1+
import pytest
22
import xarray as xr
33

4-
from mplotutils.cartopy_utils import cyclic_dataarray
4+
from mplotutils import cyclic_dataarray
55

66

7-
def test_cyclic_dataarray():
7+
@pytest.mark.parametrize("as_dataset", (True, False))
8+
def test_cyclic_dataarray(as_dataset):
89

9-
data = xr.DataArray(
10-
[[1, 2, 3], [4, 5, 6]], coords={"x": [1, 2], "y": range(3)}, dims=["x", "y"]
10+
data = [[1, 2, 3], [4, 5, 6]]
11+
da = xr.DataArray(
12+
data, dims=("y", "x"), coords={"y": [1, 2], "x": [0, 1, 2]}, name="data"
1113
)
1214

13-
res = cyclic_dataarray(data, "y")
14-
15-
expected_data = np.asarray([[1, 2, 3, 1], [4, 5, 6, 4]])
15+
expected = [[1, 2, 3, 1], [4, 5, 6, 4]]
16+
da_expected = xr.DataArray(
17+
expected, dims=("y", "x"), coords={"y": [1, 2], "x": [0, 1, 2, 0]}, name="data"
18+
)
1619

17-
np.testing.assert_allclose(res, expected_data)
20+
data = da.to_dataset() if as_dataset else da
21+
expected = da_expected.to_dataset() if as_dataset else da_expected
1822

19-
np.testing.assert_allclose(res.y, [0, 1, 2, 3])
20-
np.testing.assert_allclose(res.x, [1, 2])
23+
result = cyclic_dataarray(data, "x")
24+
xr.testing.assert_identical(result, expected)
2125

2226
# per default use 'lon'
23-
data = xr.DataArray(
24-
[[1, 2, 3], [4, 5, 6]], coords={"x": [1, 2], "lon": range(3)}, dims=["x", "lon"]
25-
)
27+
data = data.rename(x="lon")
28+
expected = expected.rename(x="lon")
2629

27-
res = cyclic_dataarray(data)
28-
np.testing.assert_allclose(res, expected_data)
30+
result = cyclic_dataarray(data)
31+
xr.testing.assert_identical(result, expected)

0 commit comments

Comments
 (0)