Skip to content

Commit 306105d

Browse files
authored
colorbar: remove assert statements (#36)
1 parent a69350d commit 306105d

File tree

3 files changed

+66
-53
lines changed

3 files changed

+66
-53
lines changed

mplotutils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# flake8: noqa
22

3-
from . import cartopy_utils, colorbar_utils, mpl_utils
3+
from . import _colorbar, cartopy_utils, mpl_utils
4+
from ._colorbar import *
45
from .cartopy_utils import *
5-
from .colorbar_utils import *
66
from .mpl_utils import *
77
from .xrcompat import *
88

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
35

@@ -145,8 +147,7 @@ def colorbar(
145147
plt.colorbar
146148
"""
147149

148-
orientations = ("vertical", "horizontal")
149-
if orientation not in orientations:
150+
if orientation not in ("vertical", "horizontal"):
150151
raise ValueError("orientation must be 'vertical' or 'horizontal'")
151152

152153
k = kwargs.keys()
@@ -156,14 +157,15 @@ def colorbar(
156157

157158
# ensure 'ax' does not end up in plt.colorbar(**kwargs)
158159
if "ax" in k:
160+
if ax2 is not None:
161+
raise ValueError("Cannot pass `ax`, and `ax2`")
159162
# assume it is ax2 (it can't be ax1)
160163
ax2 = kwargs.pop("ax")
161164

162165
f = ax1.get_figure()
163166

164-
if ax2 is not None:
165-
f2 = ax2.get_figure()
166-
assert f == f2, "'ax1' and 'ax2' must belong to the same figure"
167+
if ax2 is not None and f != ax2.get_figure():
168+
raise ValueError("'ax1' and 'ax2' must belong to the same figure")
167169

168170
cbax = _get_cbax(f)
169171

@@ -242,7 +244,7 @@ def _resize_colorbar_vert(
242244
cbax = f.add_axes([0, 0, 0.1, 0.1])
243245
cbar = plt.colorbar(h, orientation='vertical', cax=cbax)
244246
245-
func = mpu.colorbar_utils_resize_colorbar_vert(cbax, ax)
247+
func = mpu._resize_colorbar_vert(cbax, ax)
246248
f.canvas.mpl_connect('draw_event', func)
247249
248250
ax.set_global()
@@ -337,7 +339,7 @@ def _resize_colorbar_horz(
337339
cbax = f.add_axes([0, 0, 0.1, 0.1])
338340
cbar = plt.colorbar(h, orientation='horizontal', cax=cbax)
339341
340-
func = mpu.colorbar_utils._resize_colorbar_horz(cbax, ax)
342+
func = mpu._resize_colorbar_horz(cbax, ax)
341343
f.canvas.mpl_connect('draw_event', func)
342344
343345
ax.set_global()
@@ -379,11 +381,11 @@ def inner(event=None):
379381
full_width = posn2.x0 - posn1.x0 + posn2.width
380382

381383
pad_scaled = pad * posn1.height
382-
size_scaled = size * posn1.height
383384

384385
width = full_width - shrink * full_width
385386

386387
if aspect is None:
388+
size_scaled = size * posn1.height
387389
height = size_scaled
388390
else:
389391
figure_aspect = np.divide(*f.get_size_inches())
@@ -414,15 +416,16 @@ def _parse_shift_shrink(shift, shrink):
414416
if shrink is None:
415417
shrink = shift
416418

417-
assert (shift >= 0.0) & (shift <= 1.0), "'shift' must be in 0...1"
418-
assert (shrink >= 0.0) & (shrink <= 1.0), "'shrink' must be in 0...1"
419+
if (shift < 0.0) or (shift > 1.0):
420+
raise ValueError("'shift' must be in 0...1")
421+
422+
if (shrink < 0.0) or (shrink > 1.0):
423+
raise ValueError("'shrink' must be in 0...1")
419424

420425
if shift > shrink:
421-
msg = (
422-
"Warning: 'shift' is larger than 'shrink', colorbar\n"
423-
"will extend beyond the axes!"
426+
warnings.warn(
427+
"'shift' is larger than 'shrink', colorbar will extend beyond the axes"
424428
)
425-
print(msg)
426429

427430
return shift, shrink
428431

@@ -433,16 +436,12 @@ def _parse_shift_shrink(shift, shrink):
433436
def _parse_size_aspect_pad(size, aspect, pad, orientation):
434437

435438
if (size is not None) and (aspect is not None):
436-
raise ValueError("you can only pass one of 'aspect' and 'size'")
439+
raise ValueError("Can only pass one of 'aspect' and 'size'")
437440

438441
# default is aspect=20
439442
if (size is None) and (aspect is None):
440443
aspect = 20
441444

442-
# we need a large size so it is not limiting for set_aspect
443-
if aspect is not None:
444-
size = 10
445-
446445
# default mpl setting
447446
if pad is None:
448447
pad = 0.05 if orientation == "vertical" else 0.15

mplotutils/tests/test_colorbar_utils.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
3-
from pytest import raises
3+
import pytest
44

5-
from mplotutils.colorbar_utils import (
5+
from mplotutils._colorbar import (
66
_get_cbax,
77
_parse_shift_shrink,
88
_parse_size_aspect_pad,
@@ -25,54 +25,72 @@ def test_parse_shift_shrink():
2525

2626
assert _parse_shift_shrink(0.5, 0.5) == (0.5, 0.5)
2727

28-
with raises(AssertionError):
28+
with pytest.raises(ValueError, match="'shift' must be in 0...1"):
2929
_parse_shift_shrink(-0.1, 0)
3030

31-
with raises(AssertionError):
31+
with pytest.raises(ValueError, match="'shift' must be in 0...1"):
3232
_parse_shift_shrink(1.1, 0)
3333

34-
with raises(AssertionError):
34+
with pytest.raises(ValueError, match="'shrink' must be in 0...1"):
3535
_parse_shift_shrink(0, -0.1)
3636

37-
with raises(AssertionError):
37+
with pytest.raises(ValueError, match="'shrink' must be in 0...1"):
3838
_parse_shift_shrink(0, 1.1)
3939

40-
41-
# =============================================================================
40+
with pytest.warns(UserWarning, match="'shift' is larger than 'shrink'"):
41+
_parse_shift_shrink(0.6, 0.3)
4242

4343

4444
def test_parse_size_aspect_pad():
4545
"""
4646
size, aspect, pad = _parse_size_aspect_pad(size, aspect, pad, 'horizontal')
4747
"""
4848

49-
res = _parse_size_aspect_pad(0.1, None, 0.1, "horizontal")
50-
exp = (0.1, None, 0.1)
51-
assert res == exp
49+
with pytest.raises(ValueError, match="Can only pass one of 'aspect' and 'size'"):
50+
_parse_size_aspect_pad(1, 1, 0.1, "horizontal")
5251

53-
res = _parse_size_aspect_pad(None, None, 0.1, "horizontal")
54-
exp = (10, 20, 0.1)
55-
assert res == exp
52+
result = _parse_size_aspect_pad(0.1, None, 0.1, "horizontal")
53+
assert result == (0.1, None, 0.1)
5654

57-
res = _parse_size_aspect_pad(None, 20, 0.1, "horizontal")
58-
exp = (10, 20, 0.1)
59-
assert res == exp
55+
result = _parse_size_aspect_pad(None, None, 0.1, "horizontal")
56+
assert result == (None, 20, 0.1)
6057

61-
with raises(ValueError):
62-
_parse_size_aspect_pad(1, 1, 0.1, "horizontal")
58+
result = _parse_size_aspect_pad(None, 10, 0.1, "horizontal")
59+
assert result == (None, 10, 0.1)
60+
61+
result = _parse_size_aspect_pad(None, 20, 0.1, "horizontal")
62+
assert result == (None, 20, 0.1)
6363

64-
res = _parse_size_aspect_pad(None, None, None, "horizontal")
65-
exp = (10, 20, 0.15)
66-
assert res == exp
64+
result = _parse_size_aspect_pad(None, None, None, "horizontal")
65+
assert result == (None, 20, 0.15)
6766

68-
res = _parse_size_aspect_pad(None, None, None, "vertical")
69-
exp = (10, 20, 0.05)
70-
assert res == exp
67+
result = _parse_size_aspect_pad(None, None, None, "vertical")
68+
assert result == (None, 20, 0.05)
7169

7270

7371
# =============================================================================
7472

7573

74+
def test_colorbar_differnt_figures():
75+
76+
_, ax1 = plt.subplots()
77+
_, ax2 = plt.subplots()
78+
79+
h = ax1.pcolormesh([[0, 1]])
80+
81+
with pytest.raises(ValueError, match="must belong to the same figure"):
82+
colorbar(h, ax1, ax2)
83+
84+
85+
def test_colorbar_ax_and_ax2_error():
86+
87+
_, (ax1, ax2, ax3) = plt.subplots(3, 1)
88+
h = ax1.pcolormesh([[0, 1]])
89+
90+
with pytest.raises(ValueError, match="Cannot pass `ax`, and `ax2`"):
91+
colorbar(h, ax1, ax2, ax=ax3)
92+
93+
7694
def _easy_cbar_vert(**kwargs):
7795

7896
f, ax = plt.subplots()
@@ -363,19 +381,15 @@ def test_colorbar():
363381
f1, ax1 = plt.subplots()
364382
h = ax1.pcolormesh([[0, 1]])
365383

366-
with raises(ValueError):
384+
with pytest.raises(ValueError):
367385
colorbar(h, ax1, orientation="wrong")
368386

369-
with raises(ValueError):
387+
with pytest.raises(ValueError):
370388
colorbar(h, ax1, anchor=5)
371389

372-
with raises(ValueError):
390+
with pytest.raises(ValueError):
373391
colorbar(h, ax1, panchor=5)
374392

375-
with raises(AssertionError):
376-
f2, ax2 = plt.subplots()
377-
colorbar(h, ax1, ax2)
378-
379393

380394
# =============================================================================
381395

0 commit comments

Comments
 (0)