diff --git a/docs/changelog.md b/docs/changelog.md index c6e4dba..8b8eee7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,8 @@ ## 0.4.1 (_unreleased_) +- Improve plotting: Add possibility to use `explore` on Datasets. In such a case, the user can select a data variable from a dropdown menu. Use Coordinates for the dimension sliders. Add colorbars. Add slider animations. Use the new H3Layer for H3 grids. ({pull}`197`) + ## 0.4.0 (2025-11-03) - support interactive facet plots and combining maps ({pull}`183`) diff --git a/docs/tutorials/h3.ipynb b/docs/tutorials/h3.ipynb index c19d690..81552b2 100644 --- a/docs/tutorials/h3.ipynb +++ b/docs/tutorials/h3.ipynb @@ -195,7 +195,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds[\"air\"].dggs.explore()" + "ds.dggs.explore()" ] }, { @@ -217,6 +217,88 @@ " alpha=0.8, cmap=\"coolwarm\", center=273.15\n", ")" ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "We can also explore Datasets with variables consisting of different dimensions and configure them separately!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "ds[\"air_anomaly\"] = ds[\"air\"].groupby(\"time.month\") - ds[\"air\"].groupby(\n", + " \"time.month\"\n", + ").quantile([0.25, 0.5, 0.75], dim=\"time\")\n", + "ds[\"air_anomaly\"].attrs = {\"long_name\": \"Air Temperature Anomaly\", \"units\": \"K\"}\n", + "ds[\"max_air_anomaly\"] = ds[\"air_anomaly\"].max(dim=[\"time\", \"quantile\"])\n", + "ds[\"max_air_anomaly\"].attrs = {\n", + " \"long_name\": \"Maximum Air Temperature Anomaly\",\n", + " \"units\": \"K\",\n", + "}\n", + "ds.dggs.explore(\n", + " air_anomaly={\"center\": 0.0, \"cmap\": \"coolwarm\"},\n", + " max_air_anomaly={\"vmax\": 20},\n", + " robust=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "By passing a `model_kwargs`, the underlying lonboard Map can be configured. This way it is possible to create globe views:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "import ipywidgets\n", + "\n", + "s = ipywidgets.IntSlider(description=\"Test Slider\", min=0, max=10, value=5)\n", + "\n", + "a = {s: 0}\n", + "a[s]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "from lonboard.basemap import MaplibreBasemap\n", + "from lonboard.experimental.view import GlobeView\n", + "\n", + "basemap = MaplibreBasemap(mode=\"interleaved\")\n", + "view = GlobeView()\n", + "ds.dggs.explore(\n", + " air_anomaly={\"center\": 0.0, \"cmap\": \"coolwarm\"},\n", + " max_air_anomaly={\"vmax\": 20},\n", + " robust=True,\n", + " map_kwargs={\"view\": view, \"basemap\": basemap},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -230,7 +312,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.7" + "version": "3.13.9" } }, "nbformat": 4, diff --git a/xdggs/accessor.py b/xdggs/accessor.py index dc376fc..6f03df1 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import numpy.typing as npt import xarray as xr @@ -5,6 +7,9 @@ from xdggs.index import DGGSIndex from xdggs.plotting import explore +if TYPE_CHECKING: + from matplotlib.colors import Colormap + @xr.register_dataset_accessor("dggs") @xr.register_dataarray_accessor("dggs") @@ -208,38 +213,86 @@ def zoom_to(self, level: int): return xr.DataArray(zoomed, coords={self._name: self.cell_ids}, dims=dims) - def explore(self, *, cmap="viridis", center=None, alpha=None, coords=None): - """interactively explore the data using `lonboard` + def explore( + self, + *, + coords: float | None = None, + cmap: "str | Colormap | dict[str, str | Colormap]" = "viridis", + alpha: float | None = None, + center: float | dict[str, float] | None = None, + vmin: float | dict[str, float] | None = None, + vmax: float | dict[str, float] | None = None, + robust: bool = False, + map_kwargs: dict = {}, + **coloring_kwargs, + ): + """Interactively explore the data using `lonboard`. Requires `lonboard`, `matplotlib`, and `arro3.core` to be installed. Parameters ---------- - cmap : str - The name of the color map to use - center : int or float, optional - If set, will use this as the center value of a diverging color map. - alpha : float, optional - If set, controls the transparency of the polygons. coords : list of str, default: ["latitude", "longitude"] Additional coordinates to contain in the table of contents. + cmap : str or Colormap or dict[str, str or Colormap], default: "viridis" + The name of the color map to use. If a dict is provided, it can map variable + names to specific color maps. + alpha : float, optional + If set, controls the transparency of the polygons. + center : int or float or dict[str, float], optional + If set, will use this as the center value of a diverging color map. + Similar to cmap, can be a dict mapping variable names to center values. + vmin : float or dict[str, float], optional + If set, will use this as the minimum value for colormap normalization. + Similar to cmap, can be a dict mapping variable names to minimum values. + vmax : float or dict[str, float], optional + If set, will use this as the maximum value for colormap normalization. + Similar to cmap, can be a dict mapping variable names to maximum values. + robust : bool, default: False + If True, the colormap range is computed with robust quantiles (2nd and 98th percentile) + instead of the actual min and max of the data. + This is ignored if vmin and/or vmax are set. + map_kwargs : dict, optional + Additional keyword arguments are forwarded to `lonboard.Map`. + coloring_kwargs : dict, optional + Cmap, center, vmin and vmax can also be set as dictionary entries for each variable by this. + E.g. `coloring_kwargs={"air_anomaly": {"cmap": "coolwarm", "center": 0.0}}` would result in + the same as setting `cmap={"air_anomaly": "coolwarm"}` and `center={"air_anomaly": 0.0}`. Returns ------- map : lonboard.Map The rendered map. - Notes - ----- - Plotting currently is restricted to 1D `DataArray` objects. """ - if isinstance(self._obj, xr.Dataset): - raise ValueError("does not work with Dataset objects, yet") - + if coloring_kwargs: + # Manually building the dicts to override the function arguments + if not isinstance(cmap, dict): + cmap = dict.fromkeys(self._obj.data_vars, cmap) + if not isinstance(center, dict): + center = dict.fromkeys(self._obj.data_vars, center) + if not isinstance(vmin, dict): + vmin = dict.fromkeys(self._obj.data_vars, vmin) + if not isinstance(vmax, dict): + vmax = dict.fromkeys(self._obj.data_vars, vmax) + # Now all color_kwargs are dicts + for data_var, params in coloring_kwargs.items(): + if "cmap" in params: + cmap[data_var] = params["cmap"] + if "center" in params: + center[data_var] = params["center"] + if "vmin" in params: + vmin[data_var] = params["vmin"] + if "vmax" in params: + vmax[data_var] = params["vmax"] return explore( self._obj, + coords=coords, cmap=cmap, - center=center, alpha=alpha, - coords=coords, + center=center, + vmin=vmin, + vmax=vmax, + robust=robust, + **map_kwargs, ) diff --git a/xdggs/plotting.py b/xdggs/plotting.py index 870c914..c0927b2 100644 --- a/xdggs/plotting.py +++ b/xdggs/plotting.py @@ -1,54 +1,549 @@ from __future__ import annotations +import threading from dataclasses import dataclass from functools import partial -from typing import Any +from io import BytesIO +from typing import TYPE_CHECKING import ipywidgets import numpy as np import xarray as xr from lonboard import BaseLayer, Map +from xdggs.h3 import H3Info -def on_slider_change(change, container): - owner = change["owner"] - dim = owner.description +if TYPE_CHECKING: + from lonboard import Map as LonboardMap + from matplotlib.colors import CenteredNorm, Colormap, Normalize - indexers = { - slider.description: slider.value - for slider in container.dimension_sliders.children - if slider.description != dim - } | {dim: change["new"]} - new_slice = container.obj.isel(indexers) - colors = colorize(new_slice.variable, **container.colorize_kwargs) +@dataclass +class Colorizer: + colormap: Colormap + normalizer: CenteredNorm | Normalize + alpha: float | None = None + + @staticmethod + def _get_normalizer( + data, + center: float | None = None, + vmin: float | None = None, + vmax: float | None = None, + robust: bool = False, + ) -> CenteredNorm | Normalize: + from matplotlib.colors import CenteredNorm, Normalize + + # Logic: If one or both of vmin and vmax are set, use them. + # If one is not set, compute it from the data depending on robust flag. + # If neither is set, try to use center if provided. + # If center is not provided, use min and max of data, depending on robust flag. + # Robust flag means using the 2nd and 98th percentiles instead of min and max. + if vmin is not None or vmax is not None: + if vmin is None: + if robust: + vmin = np.nanpercentile(data, 2) + else: + vmin = np.nanmin(data) + if vmax is None: + if robust: + vmax = np.nanpercentile(data, 98) + else: + vmax = np.nanmax(data) + normalizer = Normalize(vmin=vmin, vmax=vmax) + elif center is not None: + if robust: + halfrange = np.abs(data - center).quantile(0.98) + else: + halfrange = np.abs(data - center).max(skipna=True) + normalizer = CenteredNorm(vcenter=center, halfrange=halfrange) + else: + if robust: + vmin = np.nanpercentile(data, 2) + vmax = np.nanpercentile(data, 98) + else: + vmin = np.nanmin(data) + vmax = np.nanmax(data) + normalizer = Normalize(vmin=vmin, vmax=vmax) + + return normalizer + + @classmethod + def for_dataset( + cls, + var_name: str, + data: xr.DataArray, + cmap: str | Colormap | dict[str, str | Colormap] = "viridis", + alpha: float | None = None, + center: float | dict[str, float] | None = None, + vmin: float | dict[str, float] | None = None, + vmax: float | dict[str, float] | None = None, + robust: bool = False, + ): + from matplotlib import colormaps - layer = container.map.layers[0] - layer.get_fill_color = colors + if isinstance(cmap, dict): + current_cmap = cmap.get(var_name, "viridis") + else: + current_cmap = cmap + if isinstance(center, dict): + current_center = center.get(var_name, None) + else: + current_center = center + if isinstance(vmin, dict): + current_vmin = vmin.get(var_name, None) + else: + current_vmin = vmin + if isinstance(vmax, dict): + current_vmax = vmax.get(var_name, None) + else: + current_vmax = vmax + + colormap = ( + colormaps[current_cmap] if isinstance(current_cmap, str) else current_cmap + ) + + normalizer = cls._get_normalizer( + data, + center=current_center, + vmin=current_vmin, + vmax=current_vmax, + robust=robust, + ) + + return cls( + colormap=colormap, + normalizer=normalizer, + alpha=alpha, + ) + + @classmethod + def for_dataarray( + cls, + data: xr.DataArray, + cmap: str | Colormap = "viridis", + alpha: float | None = None, + center: float | None = None, + vmin: float | None = None, + vmax: float | None = None, + robust: bool = False, + ): + from matplotlib import colormaps + + colormap = colormaps[cmap] if isinstance(cmap, str) else cmap + + normalizer = cls._get_normalizer( + data, + center=center, + vmin=vmin, + vmax=vmax, + robust=robust, + ) + + return cls( + colormap=colormap, + normalizer=normalizer, + alpha=alpha, + ) + + def get_cmap_preview(self, label: str): + import matplotlib.pyplot as plt + + sm = plt.cm.ScalarMappable(cmap=self.colormap, norm=self.normalizer) + fig, ax = plt.subplots(figsize=(9, 0.25)) + fig.colorbar(sm, cax=ax, orientation="horizontal", label=label) + return fig, ax + + def colorize(self, data): + from lonboard.colormap import apply_continuous_cmap + + normalized_data = self.normalizer(data) + + return apply_continuous_cmap( + normalized_data, + self.colormap, + alpha=self.alpha, + ) + + +def create_slider_widget(arr, dim): + # If the dimension has coordinates, use them as labels + # Otherwise, use integer indices + style = {"description_width": "0px"} + layout = ipywidgets.Layout(min_width="300px") + + if dim in arr.coords: + # Use a Float Slider for numeric coordinates + # Use a Select Slider for non-numeric coordinates, e.g. time or strings + coord_values = arr.coords[dim].data + if np.issubdtype(coord_values.dtype, np.number): + slider = ipywidgets.FloatSlider( + min=float(coord_values.min()), + max=float(coord_values.max()), + step=float(np.diff(np.unique(coord_values)).min()), + description=dim, + continuous_update=False, + style=style, + layout=layout, + ) + else: + # TODO: Use format for datetime display? + slider = ipywidgets.SelectionSlider( + options=list(coord_values), + description=dim, + continuous_update=False, + style=style, + layout=layout, + ) + + else: + slider = ipywidgets.IntSlider( + min=0, + max=arr.sizes[dim] - 1, + description=dim, + continuous_update=False, + style=style, + layout=layout, + ) + + return slider + + +class SliderPlayer: + """Manages play/pause functionality for a single slider.""" + + def __init__(self, slider: ipywidgets.Widget, dim: str, interval: float = 0.5): + """ + Initialize a slider player. + + Parameters + ---------- + slider : ipywidgets.Widget + The slider widget to control (IntSlider, FloatSlider, or SelectionSlider) + dim : str + The dimension name + interval : float + Time in seconds between steps when playing + """ + self.slider = slider + self.dim = dim + self.interval = interval + self.is_playing = False + self._thread = None + self._stop_event = threading.Event() + + # Create play/pause button + self.play_button = ipywidgets.Button( + description="▶", + layout=ipywidgets.Layout(width="40px"), + tooltip="Play/Pause", + ) + self.play_button.on_click(self._toggle_play) + + def _toggle_play(self, button): + """Toggle between play and pause states.""" + if self.is_playing: + self.pause() + else: + self.play() + + def play(self): + """Start playing through slider values.""" + if self.is_playing: + return + + self.is_playing = True + self.play_button.description = "⏸" + self.play_button.tooltip = "Pause" + self._stop_event.clear() + + # Start the animation thread + self._thread = threading.Thread(target=self._animate, daemon=True) + self._thread.start() + + def pause(self): + """Pause the animation.""" + if not self.is_playing: + return + + self.is_playing = False + self.play_button.description = "▶" + self.play_button.tooltip = "Play" + self._stop_event.set() + + # Wait for thread to finish + if self._thread is not None: + self._thread.join(timeout=1.0) + self._thread = None + + def _animate(self): + """Animation loop that runs in a separate thread.""" + while not self._stop_event.is_set(): + # Get current value and determine next value + if isinstance(self.slider, ipywidgets.IntSlider): + current = self.slider.value + if current >= self.slider.max: + self.slider.value = self.slider.min + else: + self.slider.value = current + self.slider.step + elif isinstance(self.slider, ipywidgets.FloatSlider): + current = self.slider.value + if current >= self.slider.max: + self.slider.value = self.slider.min + else: + self.slider.value = min(current + self.slider.step, self.slider.max) + elif isinstance(self.slider, ipywidgets.SelectionSlider): + current_index = self.slider.index + if current_index >= len(self.slider.options) - 1: + self.slider.index = 0 + else: + self.slider.index = current_index + 1 + + # Wait for the specified interval + self._stop_event.wait(self.interval) + + def widget(self): + """Return a widget with the slider and play button.""" + return ipywidgets.VBox( + [ + ipywidgets.Label(value=self.dim, style={"font_weight": "bold"}), + ipywidgets.HBox( + [self.play_button, self.slider], + layout=ipywidgets.Layout(align_items="center", gap="5px"), + ), + ] + ) + + +def create_slider_with_player(arr, dim, interval: float = 0.5): + """ + Create a slider widget with play/pause controls. + + Parameters + ---------- + arr : xr.DataArray + The data array containing the dimension + dim : str + The dimension name + interval : float + Time in seconds between steps when playing + + Returns + ------- + SliderPlayer + A slider player instance with play/pause controls + """ + slider = create_slider_widget(arr, dim) + return SliderPlayer(slider, dim=dim, interval=interval) -@dataclass class MapContainer: - """container for the map, any control widgets and the data object""" + """Container for the map, any control widgets and the data object.""" - dimension_sliders: ipywidgets.VBox - map: Map - obj: xr.DataArray + def __init__( + self, + map_: LonboardMap, + obj: xr.DataArray | xr.Dataset, + colorizer_kwargs: dict, + play_interval: float = 0.5, + ): + self.map = map_ + self.obj = obj + self.colorizer_kwargs = colorizer_kwargs + self.play_interval = play_interval + + cell_id_coord = self.obj.dggs.coord + [cell_dim] = cell_id_coord.dims + self.cell_dim = cell_dim + + self.dvar_selector = None + if isinstance(obj, xr.Dataset): + self.dvar_selector = ipywidgets.Dropdown( + options=list(obj.data_vars), + description="Variable", + continuous_update=False, + ) + self.dvar_selector.observe(self.create_sliders, names="value") + + # This creates self.colorizer, self.dimension_sliders, self.dimension_indexers, self.dimension_selectors + self.create_sliders(None) + # Quick check so that future changes to the code will fail if these attributes are missing + assert hasattr(self, "data_label") + assert hasattr(self, "colorizer") + assert hasattr(self, "dimension_sliders") + assert hasattr(self, "slider_players") + assert hasattr(self, "dimension_indexers") + assert hasattr(self, "dimension_selectors") + assert hasattr(self, "control_box") + + def _get_colorizer(self, data: xr.DataArray): + if isinstance(self.obj, xr.Dataset): + assert self.dvar_selector is not None + selected_var = self.dvar_selector.value + colorizer = Colorizer.for_dataset( + selected_var, data, **self.colorizer_kwargs + ) + else: + colorizer = Colorizer.for_dataarray(data, **self.colorizer_kwargs) + return colorizer + + def _get_arr(self): + if isinstance(self.obj, xr.Dataset): + assert self.dvar_selector is not None + selected_var = self.dvar_selector.value + arr = self.obj[selected_var] + else: + arr = self.obj + return arr - colorize_kwargs: dict[str, Any] + def create_sliders(self, change): + # Pause any existing players before recreating sliders + if hasattr(self, "slider_players"): + for player in self.slider_players.values(): + player.pause() - def render(self): - # add any additional control widgets here - control_box = ipywidgets.HBox([self.dimension_sliders]) + arr = self._get_arr() + + # Update the label information + if "long_name" in arr.attrs: + self.data_label = arr.attrs["long_name"] + else: + self.data_label = arr.name or "data" + if "units" in arr.attrs: + self.data_label += f" ({arr.attrs['units']})" - return MapWithSliders( - [self.map, control_box], layout=ipywidgets.Layout(width="100%") + # Update the colorizer + self.colorizer = self._get_colorizer(arr) + + # Update sliders based on the new variable's dimensions + # Create slider players for dimensions with more than one value + self.slider_players = { + dim: create_slider_with_player(arr, dim, interval=self.play_interval) + for dim in arr.dims + if dim != self.cell_dim and arr.sizes[dim] > 1 + } + + # Store reference to the actual sliders for easier access + self.dimension_sliders = { + dim: player.slider for dim, player in self.slider_players.items() + } + + # Reset indexers and selectors + self.dimension_indexers = { + dim: 0 + for dim, slider in self.dimension_sliders.items() + if isinstance(slider, ipywidgets.IntSlider) + } + self.dimension_selectors = { + dim: slider.value + for dim, slider in self.dimension_sliders.items() + if not isinstance(slider, ipywidgets.IntSlider) + } + + # Reconnect slider change events + for slider in self.dimension_sliders.values(): + slider.observe(partial(self.recolorize), names="value") + + self.recolorize(arr=arr) + self.create_control_box() + + def recolorize(self, change=None, arr=None): + if arr is None: + arr = self._get_arr() + + if change is not None: + dim = change["owner"].description + if dim in self.dimension_indexers: + self.dimension_indexers[dim] = change["new"] + else: + self.dimension_selectors[dim] = change["new"] + if not self.dimension_indexers and not self.dimension_selectors: + # No indexing needed + new_slice = arr + else: + new_slice = arr.isel(self.dimension_indexers).sel(self.dimension_selectors) + colors = self.colorizer.colorize(new_slice.variable) + layer = self.map.layers[0] + layer.get_fill_color = colors + + def create_control_box(self): + import matplotlib.pyplot as plt + + control_widgets = [] + if self.dvar_selector is not None: + control_widgets.append(self.dvar_selector) + if len(self.slider_players): + # Create widgets with play buttons for each slider + slider_widgets = [ + player.widget() for player in self.slider_players.values() + ] + control_widgets.append( + ipywidgets.VBox( + slider_widgets, layout={"padding": "0 10px", "margin": "0 10px"} + ) + ) + + fig, _ax = self.colorizer.get_cmap_preview(self.data_label) + buf = BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight", dpi=100) + buf.seek(0) + colorbar_widget = ipywidgets.Image(value=buf.read(), format="png") + buf.close() + plt.close(fig) + + # Create layout: controls on left, colorbar on right (wraps to new row if needed) + controls_box = ipywidgets.HBox( + control_widgets, + layout=ipywidgets.Layout( + flex="0 1 auto", min_width="fit-content", align_items="flex-start" + ), + ) + colorbar_box = ipywidgets.Box( + [colorbar_widget], + layout=ipywidgets.Layout( + flex="0 0 auto", + align_items="center", + max_width="500px", + overflow="visible", + ), + ) + + box_children = [controls_box, colorbar_box] + + if not hasattr(self, "control_box"): + # First time creation + self.control_box = ipywidgets.HBox( + box_children, + layout=ipywidgets.Layout( + width="100%", + height="auto", + align_items="flex-start", + padding="5px 0px 0px 0px", + flex_flow="row wrap", + justify_content="space-between", + overflow="visible", + ), + ) + else: + # Empty the existing box and refill + self.control_box.children = box_children + + def stop_all_players(self): + """Stop all slider players.""" + if hasattr(self, "slider_players"): + for player in self.slider_players.values(): + player.pause() + + def render(self): + return MapWithControls( + [self.map, self.control_box], + layout=ipywidgets.Layout(width="100%", overflow="hidden"), ) -def extract_maps(obj: MapGrid | MapWithSliders | Map): +def extract_maps(obj: MapGrid | MapWithControls | Map): if isinstance(obj, Map): - return obj + return (obj,) return getattr(obj, "maps", (obj.map,)) @@ -56,7 +551,7 @@ def extract_maps(obj: MapGrid | MapWithSliders | Map): class MapGrid(ipywidgets.GridBox): def __init__( self, - maps: MapWithSliders | Map = None, + maps: MapWithControls | Map = None, n_columns: int = 2, synchronize: bool = False, ): @@ -83,25 +578,25 @@ def __init__( def _replace_maps(self, maps): return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize) - def add_map(self, map_: MapWithSliders | Map): + def add_map(self, map_: MapWithControls | Map): return self._replace_maps(self.maps + (map_,)) @property def maps(self): return self.children - def __or__(self, other: MapGrid | MapWithSliders | Map): + def __or__(self, other: MapGrid | MapWithControls | Map): other_maps = extract_maps(other) return self._replace_maps(self.maps + other_maps) - def __ror__(self, other: MapWithSliders | Map): + def __ror__(self, other: MapWithControls | Map): other_maps = extract_maps(other) return self._replace_maps(self.maps + other_maps) -class MapWithSliders(ipywidgets.VBox): +class MapWithControls(ipywidgets.VBox): def change_layout(self, layout): return type(self)(self.children, layout=layout) @@ -117,8 +612,8 @@ def map(self) -> Map: def layers(self) -> list[BaseLayer]: return self.map.layers - def __or__(self, other: MapWithSliders | Map): - [other_map] = extract_maps(other) + def __or__(self, other: MapWithControls | Map): + # [other_map] = extract_maps(other) return MapGrid([self, other], synchronize=True) @@ -132,16 +627,16 @@ def _merge(self, layers, sliders): if sliders: slider_widgets.extend(sliders) - widgets = [new_map] + widget_list = [new_map] if slider_widgets: - widgets.append(ipywidgets.HBox(slider_widgets)) + widget_list.append(ipywidgets.HBox(slider_widgets)) - return type(self)(widgets, layout=self.layout) + return type(self)(widget_list, layout=self.layout) def add_layer(self, layer: BaseLayer): self.map.add_layer(layer) - def __and__(self, other: MapWithSliders | Map | BaseLayer): + def __and__(self, other: MapWithControls | Map | BaseLayer): if isinstance(other, BaseLayer): layers = [other] sliders = [] @@ -176,80 +671,106 @@ def create_arrow_table(polygons, arr, coords=None): return Table.from_arrays(list(arrow_arrays.values()), schema=schema) -def normalize(var, center=None): - from matplotlib.colors import CenteredNorm, Normalize - - if center is None: - vmin = var.min(skipna=True) - vmax = var.max(skipna=True) - normalizer = Normalize(vmin=vmin, vmax=vmax) - else: - halfrange = np.abs(var - center).max(skipna=True) - normalizer = CenteredNorm(vcenter=center, halfrange=halfrange) - - return normalizer(var.data) - - -def colorize(var, *, center, colormap, alpha): - from lonboard.colormap import apply_continuous_cmap - - normalized_data = normalize(var, center=center) - - return apply_continuous_cmap(normalized_data, colormap, alpha=alpha) - - def explore( - arr, - cmap="viridis", - center=None, - alpha=None, - coords=None, + obj: xr.DataArray | xr.Dataset, + coords: float | None = None, + cmap: str | Colormap | dict[str, str | Colormap] = "viridis", + alpha: float | None = None, + center: float | dict[str, float] | None = None, + vmin: float | dict[str, float] | None = None, + vmax: float | dict[str, float] | None = None, + robust: bool = False, + play_interval: float = 0.5, + **map_kwargs, ): import lonboard - from lonboard import SolidPolygonLayer - from matplotlib import colormaps + from lonboard import H3HexagonLayer, SolidPolygonLayer # guaranteed to be 1D - cell_id_coord = arr.dggs.coord + cell_id_coord = obj.dggs.coord [cell_dim] = cell_id_coord.dims cell_ids = cell_id_coord.data - grid_info = arr.dggs.grid_info + grid_info = obj.dggs.grid_info polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow") + if isinstance(obj, xr.Dataset): + # pick first data variable + first_var = next(iter(obj.data_vars)) + arr = obj[first_var] + colorizer = Colorizer.for_dataset( + var_name=first_var, + data=arr, + cmap=cmap, + alpha=alpha, + center=center, + vmin=vmin, + vmax=vmax, + robust=robust, + ) + else: + assert not isinstance( + cmap, dict + ), "cmap cannot be a dict when obj is a DataArray" + assert not isinstance( + center, dict + ), "center cannot be a dict when obj is a DataArray" + assert not isinstance( + vmin, dict + ), "vmin cannot be a dict when obj is a DataArray" + assert not isinstance( + vmax, dict + ), "vmax cannot be a dict when obj is a DataArray" + arr = obj + colorizer = Colorizer.for_dataarray( + data=arr, + cmap=cmap, + alpha=alpha, + center=center, + vmin=vmin, + vmax=vmax, + robust=robust, + ) + initial_indexers = {dim: 0 for dim in arr.dims if dim != cell_dim} initial_arr = arr.isel(initial_indexers) - colormap = colormaps[cmap] if isinstance(cmap, str) else cmap - colors = colorize(initial_arr, center=center, alpha=alpha, colormap=colormap) - + fill_colors = colorizer.colorize(initial_arr.variable) table = create_arrow_table(polygons, initial_arr, coords=coords) - layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=colors) - map_ = lonboard.Map(layer) + # Use the H3 Layer for H3 grid + if isinstance(grid_info, H3Info): + layer = H3HexagonLayer( + table=table, + get_hexagon=table["cell_ids"], + filled=True, + get_fill_color=fill_colors, + ) + else: + layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=fill_colors) - if not initial_indexers: - # 1D data - return map_ + map_ = lonboard.Map(layer, **map_kwargs) - sliders = ipywidgets.VBox( - [ - ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim) - for dim in arr.dims - if dim != cell_dim - ] - ) + if not initial_indexers and ( + isinstance(arr, xr.DataArray) or len(arr.data_vars) == 1 + ): + # 1D data, special case, no sliders / selectors - no interactivity needed + # This also results in a missing colorbar, since only the raw map is returned + return map_ container = MapContainer( - sliders, map_, - arr, - colorize_kwargs={"alpha": alpha, "center": center, "colormap": colormap}, + obj, + { + "cmap": cmap, + "alpha": alpha, + "center": center, + "vmin": vmin, + "vmax": vmax, + "robust": robust, + }, + play_interval=play_interval, ) - # connect slider with map - for slider in sliders.children: - slider.observe(partial(on_slider_change, container=container), names="value") - return container.render() diff --git a/xdggs/tests/test_plotting.py b/xdggs/tests/test_plotting.py index 1057729..6cfff3f 100644 --- a/xdggs/tests/test_plotting.py +++ b/xdggs/tests/test_plotting.py @@ -4,7 +4,6 @@ import pytest import xarray as xr from arro3.core import Array, Table -from matplotlib import colormaps from xdggs import plotting @@ -108,106 +107,202 @@ def test_create_arrow_table(polygons, arr, coords, expected): assert actual == expected -@pytest.mark.parametrize( - ["var", "center", "expected"], - ( - pytest.param( - xr.Variable("cells", np.array([-5, np.nan, -2, 1])), - None, - np.array([0, np.nan, 0.5, 1]), - id="linear-missing_values", - ), - pytest.param( - xr.Variable("cells", np.arange(-5, 2, dtype="float")), - None, - np.linspace(0, 1, 7), - id="linear-manual", - ), - pytest.param( - xr.Variable("cells", np.linspace(0, 10, 5)), - None, - np.linspace(0, 1, 5), - id="linear-linspace", - ), - pytest.param( - xr.Variable("cells", np.linspace(-5, 5, 10)), - 0, - np.linspace(0, 1, 10), - id="centered-0", - ), - pytest.param( - xr.Variable("cells", np.linspace(0, 10, 10)), - 5, - np.linspace(0, 1, 10), - id="centered-2", - ), - ), -) -def test_normalize(var, center, expected): - actual = plotting.normalize(var, center=center) +# Tests for normalize and colorize functions removed - they are now part of the Colorizer class - np.testing.assert_allclose(actual, expected) +class TestColorizer: + def test_for_dataarray_basic(self): + """Test basic colorizer creation from DataArray.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis") -@pytest.mark.parametrize( - ["var", "kwargs", "expected"], - ( - pytest.param( - xr.Variable("cells", [0, 3]), - {"center": 2, "colormap": colormaps["viridis"], "alpha": 1}, - np.array([[68, 1, 84], [94, 201, 97]], dtype="uint8"), - ), - pytest.param( - xr.Variable("cells", [-1, 1]), - {"center": None, "colormap": colormaps["viridis"], "alpha": 0.8}, - np.array([[68, 1, 84, 204], [253, 231, 36, 204]], dtype="uint8"), + assert colorizer.colormap.name == "viridis" + assert colorizer.normalizer.vmin == 0 + assert colorizer.normalizer.vmax == 3 + + def test_for_dataarray_with_center(self): + """Test colorizer with centered normalization.""" + data = xr.DataArray([-5, -2, 0, 2, 5], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="coolwarm", center=0) + + assert colorizer.colormap.name == "coolwarm" + assert hasattr(colorizer.normalizer, "vcenter") + assert colorizer.normalizer.vcenter == 0 + + def test_for_dataarray_with_vmin_vmax(self): + """Test colorizer with explicit vmin/vmax.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray( + data, cmap="plasma", vmin=-10, vmax=10 + ) + + assert colorizer.normalizer.vmin == -10 + assert colorizer.normalizer.vmax == 10 + + def test_for_dataarray_with_alpha(self): + """Test colorizer with alpha transparency.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis", alpha=0.5) + + assert colorizer.alpha == 0.5 + + def test_for_dataarray_robust(self): + """Test robust normalization using percentiles.""" + data = xr.DataArray([0, 1, 2, 3, 100], dims="cells") # outlier at 100 + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis", robust=True) + + # With robust=True, should use 2nd and 98th percentiles + assert colorizer.normalizer.vmin < 1 + assert colorizer.normalizer.vmax < 100 + + def test_for_dataset_basic(self): + """Test colorizer creation from Dataset.""" + ds = xr.Dataset({"temperature": xr.DataArray([10, 20, 30], dims="cells")}) + colorizer = plotting.Colorizer.for_dataset( + "temperature", ds["temperature"], cmap="viridis" + ) + + assert colorizer.colormap.name == "viridis" + assert colorizer.normalizer.vmin == 10 + assert colorizer.normalizer.vmax == 30 + + def test_for_dataset_with_dict_cmap(self): + """Test dataset colorizer with dictionary of colormaps.""" + data = xr.DataArray([10, 20, 30], dims="cells") + colorizer = plotting.Colorizer.for_dataset( + "temperature", data, cmap={"temperature": "coolwarm", "pressure": "viridis"} + ) + + assert colorizer.colormap.name == "coolwarm" + + def test_for_dataset_with_dict_center(self): + """Test dataset colorizer with dictionary of center values.""" + data = xr.DataArray([-5, 0, 5], dims="cells") + colorizer = plotting.Colorizer.for_dataset( + "temperature", + data, + cmap="coolwarm", + center={"temperature": 0, "pressure": 1000}, + ) + + assert colorizer.normalizer.vcenter == 0 + + def test_colorize_returns_array(self): + """Test that colorize returns a numpy array.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis") + + colors = colorizer.colorize(data.variable) + + assert isinstance(colors, np.ndarray) + assert colors.shape[0] == 4 # 4 cells + assert colors.shape[1] in [3, 4] # RGB or RGBA + + def test_colorize_with_alpha(self): + """Test that colorize with alpha returns RGBA.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis", alpha=0.8) + + colors = colorizer.colorize(data.variable) + + assert colors.shape[1] == 4 # RGBA + assert colors.dtype == np.uint8 + + def test_get_cmap_preview(self): + """Test colormap preview generation.""" + data = xr.DataArray([0, 1, 2, 3], dims="cells") + colorizer = plotting.Colorizer.for_dataarray(data, cmap="viridis") + + fig, ax = colorizer.get_cmap_preview("Test Label") + + assert fig is not None + assert ax is not None + + @pytest.mark.parametrize( + ["data", "kwargs", "expected"], + ( + pytest.param( + xr.DataArray([0, 3], dims="cells"), + {"cmap": "viridis", "center": 2, "alpha": 1}, + np.array([[68, 1, 84], [94, 201, 97]], dtype="uint8"), + id="centered-rgb", + ), + pytest.param( + xr.DataArray([-1, 1], dims="cells"), + {"cmap": "viridis", "center": None, "alpha": 0.8}, + np.array([[68, 1, 84, 204], [253, 231, 36, 204]], dtype="uint8"), + id="linear-rgba", + ), ), - ), -) -def test_colorize(var, kwargs, expected): - actual = plotting.colorize(var, **kwargs) + ) + def test_colorize_expected_values(self, data, kwargs, expected): + """Test colorize produces expected color arrays for specific inputs.""" + colorizer = plotting.Colorizer.for_dataarray(data, **kwargs) + actual = colorizer.colorize(data.variable) - np.testing.assert_equal(actual, expected) + np.testing.assert_equal(actual, expected) class TestMapContainer: def test_init(self): - map_ = lonboard.Map(layers=[]) - sliders = ipywidgets.VBox( - [ipywidgets.IntSlider(min=0, max=10, description="time")] - ) - obj = xr.DataArray([[0, 1], [2, 3]], dims=["time", "cells"]) - colorize_kwargs = {"a": 1, "b": 2} + """Test MapContainer initialization with proper map and data.""" + from lonboard import SolidPolygonLayer + + # Create a valid layer with geometry + obj = xr.DataArray( + [[0, 1], [2, 3]], + coords={"cell_ids": ("cells", [10, 26])}, + dims=["time", "cells"], + ).dggs.decode({"grid_name": "healpix", "level": 1, "indexing_scheme": "nested"}) + + # Get polygons from grid + cell_ids = obj.dggs.coord.data + grid_info = obj.dggs.grid_info + polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow") + table = plotting.create_arrow_table(polygons, obj.isel(time=0)) + layer = SolidPolygonLayer(table=table) + + map_ = lonboard.Map(layers=[layer]) + colorizer_kwargs = {"cmap": "viridis", "alpha": 0.8} container = plotting.MapContainer( - dimension_sliders=sliders, - map=map_, + map_=map_, obj=obj, - colorize_kwargs=colorize_kwargs, + colorizer_kwargs=colorizer_kwargs, ) assert container.map == map_ xr.testing.assert_equal(container.obj, obj) - assert container.dimension_sliders == sliders - assert container.colorize_kwargs == colorize_kwargs + assert container.colorizer_kwargs == colorizer_kwargs def test_render(self): - map_ = lonboard.Map(layers=[]) - sliders = ipywidgets.VBox( - [ipywidgets.IntSlider(min=0, max=10, description="time")] - ) - obj = xr.DataArray([[0, 1], [2, 3]], dims=["time", "cells"]) - colorize_kwargs = {"a": 1, "b": 2} + """Test MapContainer render method.""" + from lonboard import SolidPolygonLayer + + obj = xr.DataArray( + [[0, 1], [2, 3]], + coords={"cell_ids": ("cells", [10, 26])}, + dims=["time", "cells"], + ).dggs.decode({"grid_name": "healpix", "level": 1, "indexing_scheme": "nested"}) + + # Get polygons from grid + cell_ids = obj.dggs.coord.data + grid_info = obj.dggs.grid_info + polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow") + table = plotting.create_arrow_table(polygons, obj.isel(time=0)) + layer = SolidPolygonLayer(table=table) + + map_ = lonboard.Map(layers=[layer]) + colorizer_kwargs = {"cmap": "viridis"} container = plotting.MapContainer( - dimension_sliders=sliders, - map=map_, + map_=map_, obj=obj, - colorize_kwargs=colorize_kwargs, + colorizer_kwargs=colorizer_kwargs, ) rendered = container.render() - assert isinstance(rendered, ipywidgets.VBox) + assert isinstance(rendered, plotting.MapWithControls) @pytest.mark.parametrize( @@ -241,7 +336,7 @@ def test_explore(arr, expected_type): assert isinstance(actual, expected_type) -class TestMapWithSliders: +class TestMapWithControls: @pytest.mark.parametrize( ["sliders", "expected"], ( @@ -250,18 +345,18 @@ class TestMapWithSliders: ), ) def test_sliders(self, sliders, expected) -> None: - map_ = plotting.MapWithSliders([lonboard.Map(layers=[]), *sliders]) + map_ = plotting.MapWithControls([lonboard.Map(layers=[]), *sliders]) assert map_.sliders == expected or isinstance(map_.sliders[0], ipywidgets.VBox) def test_map(self): base_map = lonboard.Map(layers=[]) - wrapped_map = plotting.MapWithSliders([base_map, ipywidgets.HBox()]) + wrapped_map = plotting.MapWithControls([base_map, ipywidgets.HBox()]) assert wrapped_map.map is base_map def test_layers(self): base_map = lonboard.Map(layers=[]) - wrapped_map = plotting.MapWithSliders([base_map, ipywidgets.HBox()]) + wrapped_map = plotting.MapWithControls([base_map, ipywidgets.HBox()]) assert wrapped_map.layers == base_map.layers