Skip to content

Commit 4248b9e

Browse files
authored
fix and test map ticklabels (#47)
* fix and test map ticklabels * fixes * comment faulty test * add comment on faulty test * update code * remove strict option
1 parent edb8e16 commit 4248b9e

File tree

2 files changed

+156
-21
lines changed

2 files changed

+156
-21
lines changed

mplotutils/cartopy_utils.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import warnings
2+
13
import cartopy.crs as ccrs
24
import matplotlib.pyplot as plt
35
import numpy as np
4-
import shapely.geometry as sgeom
6+
import shapely.geometry
57
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
68

79
from .colormaps import _get_label_attr
@@ -123,7 +125,7 @@ def ylabel_map(s, labelpad=None, size=None, weight=None, y=0.5, ax=None, **kwarg
123125
rotation_mode=rotation_mode,
124126
size=size,
125127
weight=weight,
126-
**kwargs
128+
**kwargs,
127129
)
128130

129131
return h
@@ -188,7 +190,7 @@ def xlabel_map(s, labelpad=None, size=None, weight=None, x=0.5, ax=None, **kwarg
188190
rotation_mode=rotation_mode,
189191
size=size,
190192
weight=weight,
191-
**kwargs
193+
**kwargs,
192194
)
193195

194196
return h
@@ -206,7 +208,7 @@ def yticklabels(
206208
ha="right",
207209
va="center",
208210
bbox_props=dict(ec="none", fc="none"),
209-
**kwargs
211+
**kwargs,
210212
):
211213

212214
"""
@@ -238,18 +240,18 @@ def yticklabels(
238240
239241
"""
240242

241-
plt.draw()
242-
243243
# get ax if necessary
244244
if ax is None:
245245
ax = plt.gca()
246246

247+
ax.figure.canvas.draw()
248+
247249
labelpad, size, weight = _get_label_attr(labelpad, size, weight)
248250

249251
boundary_pc = _get_boundary_platecarree(ax)
250252

251253
# ensure labels are on rhs and not in the middle
252-
if len(boundary_pc) == 1:
254+
if len(boundary_pc.geoms) == 1:
253255
lonmin, lonmax = -180, 180
254256
else:
255257
lonmin, lonmax = 0, 360
@@ -265,7 +267,7 @@ def yticklabels(
265267
"WARN: no points found for ylabel\n"
266268
"y_lim is: {:0.2f} to {:0.2f}".format(y_lim[0], y_lim[1])
267269
)
268-
print(msg)
270+
warnings.warn(msg)
269271

270272
# get a transform instance that mpl understands
271273
transform = ccrs.PlateCarree()._as_mpl_transform(ax)
@@ -281,7 +283,7 @@ def yticklabels(
281283
x = _determine_intersection(boundary_pc, [lonmin, y], [lonmax, y])
282284

283285
if x.size > 0:
284-
x = x[0, 0]
286+
x = x[:, 0].min()
285287
lp = labelpad[0] + labelpad[1] * np.abs(y) / 90
286288

287289
ax.annotate(
@@ -295,7 +297,7 @@ def yticklabels(
295297
xytext=(-lp, 0),
296298
textcoords="offset points",
297299
bbox=bbox_props,
298-
**kwargs
300+
**kwargs,
299301
)
300302

301303

@@ -308,7 +310,7 @@ def xticklabels(
308310
ha="center",
309311
va="top",
310312
bbox_props=dict(ec="none", fc="none"),
311-
**kwargs
313+
**kwargs,
312314
):
313315

314316
"""
@@ -340,12 +342,17 @@ def xticklabels(
340342
341343
"""
342344

343-
plt.draw()
344-
345345
# get ax if necessary
346346
if ax is None:
347347
ax = plt.gca()
348348

349+
ax.figure.canvas.draw()
350+
351+
# proj = ccrs.PlateCarree()
352+
# points = shapely.geometry.MultiPoint([shapely.geometry.Point(x, 0) for x in x_ticks])
353+
# points = proj.project_geometry(points, proj)
354+
# x_ticks = [x.x for x in points.geoms]
355+
349356
labelpad, size, weight = _get_label_attr(labelpad, size, weight)
350357

351358
boundary_pc = _get_boundary_platecarree(ax)
@@ -361,7 +368,7 @@ def xticklabels(
361368
"WARN: no points found for xlabel\n"
362369
"x_lim is: {:0.2f} to {:0.2f}".format(x_lim[0], x_lim[1])
363370
)
364-
print(msg)
371+
warnings.warn(msg)
365372

366373
# get a transform instance that mpl understands
367374
transform = ccrs.PlateCarree()._as_mpl_transform(ax)
@@ -373,7 +380,7 @@ def xticklabels(
373380

374381
y = _determine_intersection(boundary_pc, [x, -90], [x, 90])
375382
if y.size > 0:
376-
y = y[0, 1]
383+
y = y[:, 1].min()
377384

378385
ax.annotate(
379386
msg,
@@ -386,25 +393,42 @@ def xticklabels(
386393
xytext=(0, -labelpad),
387394
textcoords="offset points",
388395
bbox=bbox_props,
389-
**kwargs
396+
**kwargs,
390397
)
391398

392399

393400
def _get_boundary_platecarree(ax):
394401
# get the bounding box of the map in lat/ lon coordinates
395402
# after ax._get_extent_geom
396403
proj = ccrs.PlateCarree()
397-
boundary_poly = sgeom.Polygon(ax.outline_patch.get_path().vertices)
404+
boundary_poly = shapely.geometry.Polygon(ax.spines["geo"].get_path().vertices)
398405
eroded_boundary = boundary_poly.buffer(-ax.projection.threshold / 100)
399406
boundary_pc = proj.project_geometry(eroded_boundary, ax.projection)
400407

408+
# boundary_pc = proj.project_geometry(boundary_poly, ax.projection)
409+
401410
return boundary_pc
402411

403412

404413
def _determine_intersection(polygon, xy1, xy2):
405414

406-
p1 = sgeom.Point(xy1)
407-
p2 = sgeom.Point(xy2)
408-
ls = sgeom.LineString([p1, p2])
415+
p1 = shapely.geometry.Point(xy1)
416+
p2 = shapely.geometry.Point(xy2)
417+
ls = shapely.geometry.LineString([p1, p2])
418+
419+
intersection = polygon.boundary.intersection(ls)
420+
421+
if isinstance(intersection, shapely.geometry.MultiPoint):
422+
arr = np.array([x.coords for x in intersection.geoms]).squeeze()
423+
elif isinstance(intersection, shapely.geometry.Point):
424+
arr = np.array([intersection.coords]).squeeze()
425+
arr = np.atleast_2d(arr)
426+
elif isinstance(intersection, shapely.geometry.LineString):
427+
if intersection.is_empty:
428+
return np.array([])
429+
else:
430+
return np.array(intersection.coords)
431+
else:
432+
raise TypeError(f"Unexpected type: {type(intersection)}")
409433

410-
return np.asarray(polygon.boundary.intersection(ls))
434+
return arr
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import sys
2+
3+
import cartopy.crs as ccrs
4+
import numpy as np
5+
6+
import mplotutils as mpu
7+
8+
from . import subplots_context
9+
10+
11+
def test_yticklabels_robinson():
12+
13+
with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax):
14+
15+
ax.set_global()
16+
17+
lat = np.arange(-90, 91, 20)
18+
19+
mpu.yticklabels(lat, ax=ax, size=8)
20+
21+
x_pos = -179.99
22+
23+
# two elements are not added because they are beyond the map limits
24+
lat = lat[1:-1]
25+
26+
# remove when dropping py 3.9
27+
strict = {"strict": True} if sys.version_info >= (3, 10) else {}
28+
for t, y_pos in zip(ax.texts, lat, **strict):
29+
30+
np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)
31+
32+
assert ax.texts[0].get_text() == "70°S"
33+
assert ax.texts[-1].get_text() == "70°N"
34+
35+
36+
def test_yticklabels_robinson_180():
37+
38+
proj = ccrs.Robinson(central_longitude=180)
39+
with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax):
40+
41+
ax.set_global()
42+
43+
lat = np.arange(-90, 91, 20)
44+
45+
mpu.yticklabels(lat, ax=ax, size=8)
46+
47+
x_pos = 0.0
48+
49+
# two elements are not added because they are beyond the map limits
50+
lat = lat[1:-1]
51+
52+
# remove when dropping py 3.9
53+
strict = {"strict": True} if sys.version_info >= (3, 10) else {}
54+
for t, y_pos in zip(ax.texts, lat, **strict):
55+
56+
np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)
57+
58+
assert ax.texts[0].get_text() == "70°S"
59+
assert ax.texts[-1].get_text() == "70°N"
60+
61+
62+
def test_xticklabels_robinson():
63+
64+
with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax):
65+
66+
ax.set_global()
67+
68+
lon = np.arange(-180, 181, 60)
69+
70+
mpu.xticklabels(lon, ax=ax, size=8)
71+
72+
y_pos = -89.99
73+
74+
# two elements are not added because they are beyond the map limits
75+
lon = lon[1:-1]
76+
77+
# remove when dropping py 3.9
78+
strict = {"strict": True} if sys.version_info >= (3, 10) else {}
79+
80+
for t, x_pos in zip(ax.texts, lon, **strict):
81+
82+
np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)
83+
84+
assert ax.texts[0].get_text() == "120°W"
85+
assert ax.texts[-1].get_text() == "120°E"
86+
87+
88+
# TODO: https://github.com/mathause/mplotutils/issues/48
89+
# def test_xticklabels_robinson_180():
90+
91+
# proj = ccrs.Robinson(central_longitude=180)
92+
# with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax):
93+
94+
# ax.set_global()
95+
96+
# # lon = np.arange(-180, 181, 60)
97+
# lon = np.arange(0, 360, 60)
98+
99+
100+
# mpu.xticklabels(lon, ax=ax, size=8)
101+
102+
# y_pos = -89.99
103+
104+
# # two elements are not added because they are beyond the map limits
105+
# lon = lon[1:-1]
106+
# for t, x_pos in zip(ax.texts, lon, strict=True):
107+
108+
# np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)
109+
110+
# assert ax.texts[0].get_text() == "60°E"
111+
# assert ax.texts[-1].get_text() == "60°W"

0 commit comments

Comments
 (0)