Skip to content

Commit 8d94f80

Browse files
authored
Merge pull request #112 from ianhi/move-stuff-around
Move more function into helper files
2 parents 3bb2ce8 + 91cbf22 commit 8d94f80

File tree

3 files changed

+320
-300
lines changed

3 files changed

+320
-300
lines changed

examples/interactive_plot-examples.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@
273273
"source": [
274274
"from ipywidgets import HBox, VBox\n",
275275
"fig, ax, sliders = interactive_plot([f1, f2], x=x, tau = tau, beta = beta, display=False)\n",
276-
"slider_vbox = VBox(sliders)\n",
276+
"slider_vbox = sliders\n",
277277
"HBox([slider_vbox, VBox([slider_vbox, fig.canvas, slider_vbox]), slider_vbox])"
278278
]
279279
},

mpl_interactions/helpers.py

Lines changed: 286 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
from collections.abc import Callable
1+
from collections import defaultdict
2+
from collections.abc import Callable, Iterable
3+
from functools import partial
24
from numbers import Number
35

6+
import ipywidgets as widgets
7+
import matplotlib.widgets as mwidgets
48
import numpy as np
9+
from IPython.display import display as ipy_display
10+
from matplotlib import __version__ as mpl_version
511
from matplotlib import get_backend
12+
from matplotlib.pyplot import axes
613
from numpy.distutils.misc_util import is_sequence
14+
from packaging import version
15+
16+
from .utils import figure, ioff
717

818
__all__ = [
919
"decompose_bbox",
@@ -16,6 +26,13 @@
1626
"broadcast_many",
1727
"notebook_backend",
1828
"callable_else_value",
29+
"kwargs_to_ipywidgets",
30+
"extract_num_options",
31+
"changeify",
32+
"kwargs_to_mpl_widgets",
33+
"create_slider_format_dict",
34+
"gogogo_figure",
35+
"gogogo_display",
1936
]
2037

2138

@@ -160,3 +177,271 @@ def callable_else_value(arg, params):
160177
if isinstance(arg, Callable):
161178
return arg(**params)
162179
return arg
180+
181+
182+
def kwargs_to_ipywidgets(kwargs, params, update, slider_format_strings):
183+
"""
184+
this will break if you pass a matplotlib slider. I suppose it could support mixed types of sliders
185+
but that doesn't really seem worthwhile?
186+
"""
187+
labels = []
188+
sliders = []
189+
controls = []
190+
for key, val in kwargs.items():
191+
if isinstance(val, set):
192+
if len(val) == 1:
193+
val = val.pop()
194+
if isinstance(val, tuple):
195+
# want the categories to be ordered
196+
pass
197+
else:
198+
# fixed parameter
199+
params[key] = val
200+
else:
201+
val = list(val)
202+
203+
# categorical
204+
if len(val) <= 3:
205+
selector = widgets.RadioButtons(options=val)
206+
else:
207+
selector = widgets.Select(options=val)
208+
params[key] = val[0]
209+
controls.append(selector)
210+
selector.observe(partial(update, key=key, label=None), names=["value"])
211+
elif isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed):
212+
if not hasattr(val, "value"):
213+
raise TypeError(
214+
"widgets passed as parameters must have the `value` trait."
215+
"But the widget passed for {key} does not have a `.value` attribute"
216+
)
217+
if isinstance(val, widgets.fixed):
218+
params[key] = val.value
219+
else:
220+
params[key] = val.value
221+
controls.append(val)
222+
val.observe(partial(update, key=key, label=None), names=["value"])
223+
else:
224+
if isinstance(val, tuple) and len(val) in [2, 3]:
225+
# treat as an argument to linspace
226+
# idk if it's acceptable to overwrite kwargs like this
227+
# but I think at this point kwargs is just a dict like any other
228+
val = np.linspace(*val)
229+
kwargs[key] = val
230+
val = np.atleast_1d(val)
231+
if val.ndim > 1:
232+
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
233+
if len(val) == 1:
234+
# don't need to create a slider
235+
params[key] = val
236+
else:
237+
params[key] = val[0]
238+
labels.append(widgets.Label(value=slider_format_strings[key].format(val[0])))
239+
sliders.append(
240+
widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key)
241+
)
242+
controls.append(widgets.HBox([sliders[-1], labels[-1]]))
243+
sliders[-1].observe(partial(update, key=key, label=labels[-1]), names=["value"])
244+
return sliders, labels, controls
245+
246+
247+
def extract_num_options(val):
248+
"""
249+
convert a categorical to a number of options
250+
"""
251+
if len(val) == 1:
252+
for v in val:
253+
if isinstance(v, tuple):
254+
# this looks nightmarish...
255+
# but i think it should always work
256+
# should also check if the tuple has length one here.
257+
# that will only be an issue if a trailing comma was used to make the tuple ('beep',)
258+
# but not ('beep') - the latter is not actually a tuple
259+
return len(v)
260+
else:
261+
return 0
262+
else:
263+
return len(val)
264+
265+
266+
def changeify(val, key, update):
267+
"""
268+
make matplotlib update functions return a dict with key 'new'.
269+
Do this for compatibility with ipywidgets
270+
"""
271+
update({"new": val}, key, None)
272+
273+
274+
# this is a bunch of hacky nonsense
275+
# making it involved me holding a ruler up to my monitor
276+
# if you have a better solution I would love to hear about it :)
277+
# - Ian 2020-08-22
278+
def kwargs_to_mpl_widgets(kwargs, params, update, slider_format_strings):
279+
n_opts = 0
280+
n_radio = 0
281+
n_sliders = 0
282+
for key, val in kwargs.items():
283+
if isinstance(val, set):
284+
new_opts = extract_num_options(val)
285+
if new_opts > 0:
286+
n_radio += 1
287+
n_opts += new_opts
288+
elif (
289+
not isinstance(val, mwidgets.AxesWidget)
290+
and not isinstance(val, widgets.fixed)
291+
and isinstance(val, Iterable)
292+
and len(val) > 1
293+
):
294+
n_sliders += 1
295+
296+
# These are roughly the sizes used in the matplotlib widget tutorial
297+
# https://matplotlib.org/3.2.2/gallery/widgets/slider_demo.html#sphx-glr-gallery-widgets-slider-demo-py
298+
slider_in = 0.15
299+
radio_in = 0.6 / 3
300+
widget_gap_in = 0.1
301+
302+
widget_inches = (
303+
n_sliders * slider_in + n_opts * radio_in + widget_gap_in * (n_sliders + n_radio + 1) + 0.5
304+
) # half an inch for margin
305+
fig = None
306+
if not all(map(lambda x: isinstance(x, mwidgets.AxesWidget), kwargs.values())):
307+
# if the only kwargs are existing matplotlib widgets don't make a new figure
308+
with ioff:
309+
fig = figure()
310+
size = fig.get_size_inches()
311+
fig_h = widget_inches
312+
fig.set_size_inches(size[0], widget_inches)
313+
slider_height = slider_in / fig_h
314+
radio_height = radio_in / fig_h
315+
# radio
316+
gap_height = widget_gap_in / fig_h
317+
widget_y = 0.05
318+
slider_ax = []
319+
sliders = []
320+
radio_ax = []
321+
radio_buttons = []
322+
cbs = []
323+
for key, val in kwargs.items():
324+
if isinstance(val, set):
325+
if len(val) == 1:
326+
val = val.pop()
327+
if isinstance(val, tuple):
328+
pass
329+
else:
330+
params[key] = val
331+
continue
332+
else:
333+
val = list(val)
334+
335+
n = len(val)
336+
longest_len = max(list(map(lambda x: len(list(x)), map(str, val))))
337+
# should probably use something based on fontsize rather that .015
338+
width = max(0.15, 0.015 * longest_len)
339+
radio_ax.append(axes([0.2, 0.9 - widget_y - radio_height * n, width, radio_height * n]))
340+
widget_y += radio_height * n + gap_height
341+
radio_buttons.append(mwidgets.RadioButtons(radio_ax[-1], val, active=0))
342+
cbs.append(radio_buttons[-1].on_clicked(partial(changeify, key=key, update=update)))
343+
params[key] = val[0]
344+
elif isinstance(val, mwidgets.RadioButtons):
345+
val.on_clicked(partial(changeify, key=key, update=update))
346+
params[key] = val.val
347+
elif isinstance(val, mwidgets.Slider):
348+
val.on_changed(partial(changeify, key=key, update=update))
349+
params[key] = val.val
350+
else:
351+
if isinstance(val, tuple):
352+
if len(val) == 2:
353+
min_ = val[0]
354+
max_ = val[1]
355+
elif len(val) == 3:
356+
# should warn that that doesn't make sense with matplotlib sliders
357+
min_ = val[0]
358+
max_ = val[1]
359+
else:
360+
val = np.atleast_1d(val)
361+
if val.ndim > 1:
362+
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
363+
if len(val) == 1:
364+
# don't need to create a slider
365+
params[key] = val[0]
366+
continue
367+
else:
368+
# list or numpy array
369+
# should warn here as well
370+
min_ = np.min(val)
371+
max_ = np.max(val)
372+
373+
slider_ax.append(axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height]))
374+
sliders.append(
375+
mwidgets.Slider(
376+
slider_ax[-1],
377+
key,
378+
min_,
379+
max_,
380+
valinit=min_,
381+
valfmt=slider_format_strings[key],
382+
)
383+
)
384+
cbs.append(sliders[-1].on_changed(partial(changeify, key=key, update=update)))
385+
widget_y += slider_height + gap_height
386+
params[key] = min_
387+
controls = [fig, radio_ax, radio_buttons, slider_ax, sliders]
388+
return controls
389+
390+
391+
def create_slider_format_dict(slider_format_string, use_ipywidgets):
392+
# mpl sliders for verison 3.3 and onwards support None as an argument for valfmt
393+
mpl_gr_33 = version.parse(mpl_version) >= version.parse("3.3")
394+
if isinstance(slider_format_string, str):
395+
slider_format_strings = defaultdict(lambda: slider_format_string)
396+
elif isinstance(slider_format_string, dict) or slider_format_string is None:
397+
if use_ipywidgets:
398+
slider_format_strings = defaultdict(lambda: "{:.2f}")
399+
elif mpl_gr_33:
400+
slider_format_strings = defaultdict(lambda: None)
401+
else:
402+
slider_format_strings = defaultdict(lambda: "%1.2f")
403+
404+
if slider_format_string is not None:
405+
for key, val in slider_format_string.items():
406+
slider_format_strings[key] = val
407+
else:
408+
raise ValueError(
409+
f"slider_format_string must be a dict or a string but it is a {type(slider_format_string)}"
410+
)
411+
return slider_format_strings
412+
413+
414+
def gogogo_figure(ipympl, figsize, ax=None):
415+
"""
416+
gogogo the greatest function name of all
417+
"""
418+
if ax is None:
419+
if ipympl:
420+
with ioff:
421+
fig = figure(figsize=figsize)
422+
ax = fig.gca()
423+
else:
424+
fig = figure(figsize=figsize)
425+
ax = fig.gca()
426+
return fig, ax
427+
else:
428+
return ax.get_figure(), ax
429+
430+
431+
def gogogo_display(ipympl, use_ipywidgets, display, controls, fig):
432+
if use_ipywidgets:
433+
controls = widgets.VBox(controls)
434+
if display:
435+
if ipympl:
436+
ipy_display(widgets.VBox([controls, fig.canvas]))
437+
else:
438+
# for the case of using %matplotlib qt
439+
# but also want ipywidgets sliders
440+
# ie with force_ipywidgets = True
441+
ipy_display(controls)
442+
fig.show()
443+
else:
444+
if display:
445+
fig.show()
446+
controls[0].show()
447+
return controls

0 commit comments

Comments
 (0)