Skip to content
2 changes: 1 addition & 1 deletion pypesto/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
projection_scatter_umap,
projection_scatter_umap_original,
)
from .ensemble import ensemble_identifiability
from .ensemble import ensemble_identifiability, ensemble_parameters_plot
from .misc import process_offset_y, process_result_list, process_y_limits
from .observable_mapping import (
plot_linear_observable_mappings_from_pypesto_result,
Expand Down
68 changes: 68 additions & 0 deletions pypesto/visualize/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colormaps
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

Expand Down Expand Up @@ -413,3 +414,70 @@ def _create_patches(
)

return patches_both_hit, patches_lb_hit, patches_ub_hit, patches_none_hit


def ensemble_parameters_plot(
ensemble: Ensemble,
ax: Optional[plt.Axes] = None,
parameter_ids: Optional[list[int]] = None,
size: Optional[tuple[float]] = (6, 12)
):
"""
Visualize parameter ensemble.

Parameters
----------
ensemble:
ensemble of parameter vectors (from pypesto.ensemble).
Comment on lines +430 to +431
Copy link
Contributor

@Doresic Doresic Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a method to generally just visualize the ensemble, so we cannot really say that the rectangles represent, for example, 95% confidence intervals. But one would need to create an ensemble in some way, and then subset it using a xi^2 threshold to get confidence intervals, right?
Would it be good to have this showcased somewhere, like in a notebook that mentiones ensembles or identifiability? How to use this visualization correctly? Or is it ok to just expand the description of the visualization in the docstring to describe that it will just visualize all members of the ensemble and that these do not represent CIs.

ax:
Axes object to use.
parameter_ids:
Indices of parameters to plot.
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.

Returns
-------
ax: matplotlib.Axes
The plot axes.
"""

if ax is None:
fig, ax = plt.subplots(figsize=size)

if parameter_ids:
x_vectors = ensemble.x_vectors[parameter_ids]
n_x = len(parameter_ids)
else:
parameter_ids = np.arange(ensemble.n_x)
x_vectors = ensemble.x_vectors
n_x = ensemble.n_x

y_rect = -0.4
h_rect = 0.8 # rectangle height
rectangles = []
cmap = colormaps['Greys']
colors = np.flip(cmap(np.linspace(0.3, 0.8, (ensemble.n_vectors-1))), axis=0)
colors = np.insert(colors, 0, [1., 0., 0., 1.], axis=0)

for i, par_values in enumerate(x_vectors):
w_rect = np.max(par_values) - np.min(par_values) # rectangle width
rectangles.append(
Rectangle((np.min(par_values), y_rect), w_rect, h_rect))
y_rect += h_rect + 0.2
ax.add_collection(PatchCollection(rectangles, facecolors=[1., 1., 1., 1.], edgecolors='dimgrey'))

for i, v in enumerate(x_vectors):
ax.scatter(x=v, y=[i]*ensemble.n_vectors, s=40, color=colors, alpha=0.6)
# plot the best parameter values
ax.scatter(x_vectors[:, 0], np.arange(n_x), s=40,
color=[1., 0., 0., 1.])

ax.plot(ensemble.lower_bound[parameter_ids], np.arange(n_x), '--', color='grey')
ax.plot(ensemble.upper_bound[parameter_ids], np.arange(n_x), '--', color='grey')
ax.set_xlim(np.min(ensemble.lower_bound) * 1.1, np.max(ensemble.upper_bound) * 1.1)
plt.yticks(np.arange(n_x), np.asarray(ensemble.x_names)[parameter_ids])
plt.tight_layout()

return ax
16 changes: 16 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,22 @@ def test_ensemble_identifiability():
# test plotting from a collection object
visualize.ensemble_identifiability(my_ensemble)

@close_fig
def test_ensemble_parameters_plot():
# creates a test problem
problem = create_problem(n_parameters=100)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

such a big one necessary for the test? 🙈


my_ensemble = [
(1 + np.cos(ix) ** 2) * np.random.rand(500) - 1.0 + np.sin(ix)
for ix in range(100)
]
my_ensemble = ensemble.Ensemble(
np.array(my_ensemble), lower_bound=problem.lb, upper_bound=problem.ub
)

visualize.ensemble_parameters_plot(my_ensemble)
visualize.ensemble_parameters_plot(my_ensemble, parameter_ids=[0,5,8,13,17,33,45,76,82,88,90])


@close_fig
def test_profiles():
Expand Down
Loading