|
1 | | -from collections.abc import Callable |
| 1 | +from collections import defaultdict |
| 2 | +from collections.abc import Callable, Iterable |
| 3 | +from functools import partial |
2 | 4 | from numbers import Number |
3 | 5 |
|
| 6 | +import ipywidgets as widgets |
| 7 | +import matplotlib.widgets as mwidgets |
4 | 8 | import numpy as np |
| 9 | +from IPython.display import display as ipy_display |
| 10 | +from matplotlib import __version__ as mpl_version |
5 | 11 | from matplotlib import get_backend |
| 12 | +from matplotlib.pyplot import axes |
6 | 13 | from numpy.distutils.misc_util import is_sequence |
| 14 | +from packaging import version |
| 15 | + |
| 16 | +from .utils import figure, ioff |
7 | 17 |
|
8 | 18 | __all__ = [ |
9 | 19 | "decompose_bbox", |
|
16 | 26 | "broadcast_many", |
17 | 27 | "notebook_backend", |
18 | 28 | "callable_else_value", |
| 29 | + "kwargs_to_ipywidgets", |
| 30 | + "extract_num_options", |
| 31 | + "changeify", |
| 32 | + "kwargs_to_mpl_widgets", |
| 33 | + "create_slider_format_dict", |
| 34 | + "gogogo_figure", |
| 35 | + "gogogo_display", |
19 | 36 | ] |
20 | 37 |
|
21 | 38 |
|
@@ -160,3 +177,271 @@ def callable_else_value(arg, params): |
160 | 177 | if isinstance(arg, Callable): |
161 | 178 | return arg(**params) |
162 | 179 | return arg |
| 180 | + |
| 181 | + |
| 182 | +def kwargs_to_ipywidgets(kwargs, params, update, slider_format_strings): |
| 183 | + """ |
| 184 | + this will break if you pass a matplotlib slider. I suppose it could support mixed types of sliders |
| 185 | + but that doesn't really seem worthwhile? |
| 186 | + """ |
| 187 | + labels = [] |
| 188 | + sliders = [] |
| 189 | + controls = [] |
| 190 | + for key, val in kwargs.items(): |
| 191 | + if isinstance(val, set): |
| 192 | + if len(val) == 1: |
| 193 | + val = val.pop() |
| 194 | + if isinstance(val, tuple): |
| 195 | + # want the categories to be ordered |
| 196 | + pass |
| 197 | + else: |
| 198 | + # fixed parameter |
| 199 | + params[key] = val |
| 200 | + else: |
| 201 | + val = list(val) |
| 202 | + |
| 203 | + # categorical |
| 204 | + if len(val) <= 3: |
| 205 | + selector = widgets.RadioButtons(options=val) |
| 206 | + else: |
| 207 | + selector = widgets.Select(options=val) |
| 208 | + params[key] = val[0] |
| 209 | + controls.append(selector) |
| 210 | + selector.observe(partial(update, key=key, label=None), names=["value"]) |
| 211 | + elif isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed): |
| 212 | + if not hasattr(val, "value"): |
| 213 | + raise TypeError( |
| 214 | + "widgets passed as parameters must have the `value` trait." |
| 215 | + "But the widget passed for {key} does not have a `.value` attribute" |
| 216 | + ) |
| 217 | + if isinstance(val, widgets.fixed): |
| 218 | + params[key] = val.value |
| 219 | + else: |
| 220 | + params[key] = val.value |
| 221 | + controls.append(val) |
| 222 | + val.observe(partial(update, key=key, label=None), names=["value"]) |
| 223 | + else: |
| 224 | + if isinstance(val, tuple) and len(val) in [2, 3]: |
| 225 | + # treat as an argument to linspace |
| 226 | + # idk if it's acceptable to overwrite kwargs like this |
| 227 | + # but I think at this point kwargs is just a dict like any other |
| 228 | + val = np.linspace(*val) |
| 229 | + kwargs[key] = val |
| 230 | + val = np.atleast_1d(val) |
| 231 | + if val.ndim > 1: |
| 232 | + raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") |
| 233 | + if len(val) == 1: |
| 234 | + # don't need to create a slider |
| 235 | + params[key] = val |
| 236 | + else: |
| 237 | + params[key] = val[0] |
| 238 | + labels.append(widgets.Label(value=slider_format_strings[key].format(val[0]))) |
| 239 | + sliders.append( |
| 240 | + widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key) |
| 241 | + ) |
| 242 | + controls.append(widgets.HBox([sliders[-1], labels[-1]])) |
| 243 | + sliders[-1].observe(partial(update, key=key, label=labels[-1]), names=["value"]) |
| 244 | + return sliders, labels, controls |
| 245 | + |
| 246 | + |
| 247 | +def extract_num_options(val): |
| 248 | + """ |
| 249 | + convert a categorical to a number of options |
| 250 | + """ |
| 251 | + if len(val) == 1: |
| 252 | + for v in val: |
| 253 | + if isinstance(v, tuple): |
| 254 | + # this looks nightmarish... |
| 255 | + # but i think it should always work |
| 256 | + # should also check if the tuple has length one here. |
| 257 | + # that will only be an issue if a trailing comma was used to make the tuple ('beep',) |
| 258 | + # but not ('beep') - the latter is not actually a tuple |
| 259 | + return len(v) |
| 260 | + else: |
| 261 | + return 0 |
| 262 | + else: |
| 263 | + return len(val) |
| 264 | + |
| 265 | + |
| 266 | +def changeify(val, key, update): |
| 267 | + """ |
| 268 | + make matplotlib update functions return a dict with key 'new'. |
| 269 | + Do this for compatibility with ipywidgets |
| 270 | + """ |
| 271 | + update({"new": val}, key, None) |
| 272 | + |
| 273 | + |
| 274 | +# this is a bunch of hacky nonsense |
| 275 | +# making it involved me holding a ruler up to my monitor |
| 276 | +# if you have a better solution I would love to hear about it :) |
| 277 | +# - Ian 2020-08-22 |
| 278 | +def kwargs_to_mpl_widgets(kwargs, params, update, slider_format_strings): |
| 279 | + n_opts = 0 |
| 280 | + n_radio = 0 |
| 281 | + n_sliders = 0 |
| 282 | + for key, val in kwargs.items(): |
| 283 | + if isinstance(val, set): |
| 284 | + new_opts = extract_num_options(val) |
| 285 | + if new_opts > 0: |
| 286 | + n_radio += 1 |
| 287 | + n_opts += new_opts |
| 288 | + elif ( |
| 289 | + not isinstance(val, mwidgets.AxesWidget) |
| 290 | + and not isinstance(val, widgets.fixed) |
| 291 | + and isinstance(val, Iterable) |
| 292 | + and len(val) > 1 |
| 293 | + ): |
| 294 | + n_sliders += 1 |
| 295 | + |
| 296 | + # These are roughly the sizes used in the matplotlib widget tutorial |
| 297 | + # https://matplotlib.org/3.2.2/gallery/widgets/slider_demo.html#sphx-glr-gallery-widgets-slider-demo-py |
| 298 | + slider_in = 0.15 |
| 299 | + radio_in = 0.6 / 3 |
| 300 | + widget_gap_in = 0.1 |
| 301 | + |
| 302 | + widget_inches = ( |
| 303 | + n_sliders * slider_in + n_opts * radio_in + widget_gap_in * (n_sliders + n_radio + 1) + 0.5 |
| 304 | + ) # half an inch for margin |
| 305 | + fig = None |
| 306 | + if not all(map(lambda x: isinstance(x, mwidgets.AxesWidget), kwargs.values())): |
| 307 | + # if the only kwargs are existing matplotlib widgets don't make a new figure |
| 308 | + with ioff: |
| 309 | + fig = figure() |
| 310 | + size = fig.get_size_inches() |
| 311 | + fig_h = widget_inches |
| 312 | + fig.set_size_inches(size[0], widget_inches) |
| 313 | + slider_height = slider_in / fig_h |
| 314 | + radio_height = radio_in / fig_h |
| 315 | + # radio |
| 316 | + gap_height = widget_gap_in / fig_h |
| 317 | + widget_y = 0.05 |
| 318 | + slider_ax = [] |
| 319 | + sliders = [] |
| 320 | + radio_ax = [] |
| 321 | + radio_buttons = [] |
| 322 | + cbs = [] |
| 323 | + for key, val in kwargs.items(): |
| 324 | + if isinstance(val, set): |
| 325 | + if len(val) == 1: |
| 326 | + val = val.pop() |
| 327 | + if isinstance(val, tuple): |
| 328 | + pass |
| 329 | + else: |
| 330 | + params[key] = val |
| 331 | + continue |
| 332 | + else: |
| 333 | + val = list(val) |
| 334 | + |
| 335 | + n = len(val) |
| 336 | + longest_len = max(list(map(lambda x: len(list(x)), map(str, val)))) |
| 337 | + # should probably use something based on fontsize rather that .015 |
| 338 | + width = max(0.15, 0.015 * longest_len) |
| 339 | + radio_ax.append(axes([0.2, 0.9 - widget_y - radio_height * n, width, radio_height * n])) |
| 340 | + widget_y += radio_height * n + gap_height |
| 341 | + radio_buttons.append(mwidgets.RadioButtons(radio_ax[-1], val, active=0)) |
| 342 | + cbs.append(radio_buttons[-1].on_clicked(partial(changeify, key=key, update=update))) |
| 343 | + params[key] = val[0] |
| 344 | + elif isinstance(val, mwidgets.RadioButtons): |
| 345 | + val.on_clicked(partial(changeify, key=key, update=update)) |
| 346 | + params[key] = val.val |
| 347 | + elif isinstance(val, mwidgets.Slider): |
| 348 | + val.on_changed(partial(changeify, key=key, update=update)) |
| 349 | + params[key] = val.val |
| 350 | + else: |
| 351 | + if isinstance(val, tuple): |
| 352 | + if len(val) == 2: |
| 353 | + min_ = val[0] |
| 354 | + max_ = val[1] |
| 355 | + elif len(val) == 3: |
| 356 | + # should warn that that doesn't make sense with matplotlib sliders |
| 357 | + min_ = val[0] |
| 358 | + max_ = val[1] |
| 359 | + else: |
| 360 | + val = np.atleast_1d(val) |
| 361 | + if val.ndim > 1: |
| 362 | + raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") |
| 363 | + if len(val) == 1: |
| 364 | + # don't need to create a slider |
| 365 | + params[key] = val[0] |
| 366 | + continue |
| 367 | + else: |
| 368 | + # list or numpy array |
| 369 | + # should warn here as well |
| 370 | + min_ = np.min(val) |
| 371 | + max_ = np.max(val) |
| 372 | + |
| 373 | + slider_ax.append(axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height])) |
| 374 | + sliders.append( |
| 375 | + mwidgets.Slider( |
| 376 | + slider_ax[-1], |
| 377 | + key, |
| 378 | + min_, |
| 379 | + max_, |
| 380 | + valinit=min_, |
| 381 | + valfmt=slider_format_strings[key], |
| 382 | + ) |
| 383 | + ) |
| 384 | + cbs.append(sliders[-1].on_changed(partial(changeify, key=key, update=update))) |
| 385 | + widget_y += slider_height + gap_height |
| 386 | + params[key] = min_ |
| 387 | + controls = [fig, radio_ax, radio_buttons, slider_ax, sliders] |
| 388 | + return controls |
| 389 | + |
| 390 | + |
| 391 | +def create_slider_format_dict(slider_format_string, use_ipywidgets): |
| 392 | + # mpl sliders for verison 3.3 and onwards support None as an argument for valfmt |
| 393 | + mpl_gr_33 = version.parse(mpl_version) >= version.parse("3.3") |
| 394 | + if isinstance(slider_format_string, str): |
| 395 | + slider_format_strings = defaultdict(lambda: slider_format_string) |
| 396 | + elif isinstance(slider_format_string, dict) or slider_format_string is None: |
| 397 | + if use_ipywidgets: |
| 398 | + slider_format_strings = defaultdict(lambda: "{:.2f}") |
| 399 | + elif mpl_gr_33: |
| 400 | + slider_format_strings = defaultdict(lambda: None) |
| 401 | + else: |
| 402 | + slider_format_strings = defaultdict(lambda: "%1.2f") |
| 403 | + |
| 404 | + if slider_format_string is not None: |
| 405 | + for key, val in slider_format_string.items(): |
| 406 | + slider_format_strings[key] = val |
| 407 | + else: |
| 408 | + raise ValueError( |
| 409 | + f"slider_format_string must be a dict or a string but it is a {type(slider_format_string)}" |
| 410 | + ) |
| 411 | + return slider_format_strings |
| 412 | + |
| 413 | + |
| 414 | +def gogogo_figure(ipympl, figsize, ax=None): |
| 415 | + """ |
| 416 | + gogogo the greatest function name of all |
| 417 | + """ |
| 418 | + if ax is None: |
| 419 | + if ipympl: |
| 420 | + with ioff: |
| 421 | + fig = figure(figsize=figsize) |
| 422 | + ax = fig.gca() |
| 423 | + else: |
| 424 | + fig = figure(figsize=figsize) |
| 425 | + ax = fig.gca() |
| 426 | + return fig, ax |
| 427 | + else: |
| 428 | + return ax.get_figure(), ax |
| 429 | + |
| 430 | + |
| 431 | +def gogogo_display(ipympl, use_ipywidgets, display, controls, fig): |
| 432 | + if use_ipywidgets: |
| 433 | + controls = widgets.VBox(controls) |
| 434 | + if display: |
| 435 | + if ipympl: |
| 436 | + ipy_display(widgets.VBox([controls, fig.canvas])) |
| 437 | + else: |
| 438 | + # for the case of using %matplotlib qt |
| 439 | + # but also want ipywidgets sliders |
| 440 | + # ie with force_ipywidgets = True |
| 441 | + ipy_display(controls) |
| 442 | + fig.show() |
| 443 | + else: |
| 444 | + if display: |
| 445 | + fig.show() |
| 446 | + controls[0].show() |
| 447 | + return controls |
0 commit comments