diff --git a/pypesto/visualize/__init__.py b/pypesto/visualize/__init__.py index 7318cf9c9..25f1667e6 100644 --- a/pypesto/visualize/__init__.py +++ b/pypesto/visualize/__init__.py @@ -17,7 +17,7 @@ projection_scatter_umap, projection_scatter_umap_original, ) -from .ensemble import ensemble_identifiability +from .ensemble import ensemble_identifiability, ensemble_parameters_plot from .misc import process_offset_y, process_result_list, process_y_limits from .observable_mapping import ( plot_linear_observable_mappings_from_pypesto_result, diff --git a/pypesto/visualize/ensemble.py b/pypesto/visualize/ensemble.py index 413341223..fb2922dfe 100644 --- a/pypesto/visualize/ensemble.py +++ b/pypesto/visualize/ensemble.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from matplotlib import colormaps from matplotlib.collections import PatchCollection from matplotlib.patches import Rectangle @@ -413,3 +414,70 @@ def _create_patches( ) return patches_both_hit, patches_lb_hit, patches_ub_hit, patches_none_hit + + +def ensemble_parameters_plot( + ensemble: Ensemble, + ax: Optional[plt.Axes] = None, + parameter_ids: Optional[list[int]] = None, + size: Optional[tuple[float]] = (6, 12) +): + """ + Visualize parameter ensemble. + + Parameters + ---------- + ensemble: + ensemble of parameter vectors (from pypesto.ensemble). + ax: + Axes object to use. + parameter_ids: + Indices of parameters to plot. + size: + Figure size (width, height) in inches. Is only applied when no ax + object is specified. + + Returns + ------- + ax: matplotlib.Axes + The plot axes. + """ + + if ax is None: + fig, ax = plt.subplots(figsize=size) + + if parameter_ids: + x_vectors = ensemble.x_vectors[parameter_ids] + n_x = len(parameter_ids) + else: + parameter_ids = np.arange(ensemble.n_x) + x_vectors = ensemble.x_vectors + n_x = ensemble.n_x + + y_rect = -0.4 + h_rect = 0.8 # rectangle height + rectangles = [] + cmap = colormaps['Greys'] + colors = np.flip(cmap(np.linspace(0.3, 0.8, (ensemble.n_vectors-1))), axis=0) + colors = np.insert(colors, 0, [1., 0., 0., 1.], axis=0) + + for i, par_values in enumerate(x_vectors): + w_rect = np.max(par_values) - np.min(par_values) # rectangle width + rectangles.append( + Rectangle((np.min(par_values), y_rect), w_rect, h_rect)) + y_rect += h_rect + 0.2 + ax.add_collection(PatchCollection(rectangles, facecolors=[1., 1., 1., 1.], edgecolors='dimgrey')) + + for i, v in enumerate(x_vectors): + ax.scatter(x=v, y=[i]*ensemble.n_vectors, s=40, color=colors, alpha=0.6) + # plot the best parameter values + ax.scatter(x_vectors[:, 0], np.arange(n_x), s=40, + color=[1., 0., 0., 1.]) + + ax.plot(ensemble.lower_bound[parameter_ids], np.arange(n_x), '--', color='grey') + ax.plot(ensemble.upper_bound[parameter_ids], np.arange(n_x), '--', color='grey') + ax.set_xlim(np.min(ensemble.lower_bound) * 1.1, np.max(ensemble.upper_bound) * 1.1) + plt.yticks(np.arange(n_x), np.asarray(ensemble.x_names)[parameter_ids]) + plt.tight_layout() + + return ax diff --git a/test/visualize/test_visualize.py b/test/visualize/test_visualize.py index 4241e59fe..88d312f31 100644 --- a/test/visualize/test_visualize.py +++ b/test/visualize/test_visualize.py @@ -646,6 +646,22 @@ def test_ensemble_identifiability(): # test plotting from a collection object visualize.ensemble_identifiability(my_ensemble) +@close_fig +def test_ensemble_parameters_plot(): + # creates a test problem + problem = create_problem(n_parameters=100) + + my_ensemble = [ + (1 + np.cos(ix) ** 2) * np.random.rand(500) - 1.0 + np.sin(ix) + for ix in range(100) + ] + my_ensemble = ensemble.Ensemble( + np.array(my_ensemble), lower_bound=problem.lb, upper_bound=problem.ub + ) + + visualize.ensemble_parameters_plot(my_ensemble) + visualize.ensemble_parameters_plot(my_ensemble, parameter_ids=[0,5,8,13,17,33,45,76,82,88,90]) + @close_fig def test_profiles():