Skip to content
2 changes: 1 addition & 1 deletion pypesto/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
projection_scatter_umap,
projection_scatter_umap_original,
)
from .ensemble import ensemble_identifiability
from .ensemble import ensemble_identifiability, ensemble_parameters_plot
from .misc import process_offset_y, process_result_list, process_y_limits
from .observable_mapping import (
plot_linear_observable_mappings_from_pypesto_result,
Expand Down
57 changes: 57 additions & 0 deletions pypesto/visualize/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colormaps
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

Expand Down Expand Up @@ -413,3 +414,59 @@ def _create_patches(
)

return patches_both_hit, patches_lb_hit, patches_ub_hit, patches_none_hit


def ensemble_parameters_plot(
ensemble: Ensemble,
ax: Optional[plt.Axes] = None,
size: Optional[tuple[float]] = (12, 6)
):
"""
Visualize parameter ensemble.

Parameters
----------
ensemble:
ensemble of parameter vectors (from pypesto.ensemble)
ax:
Axes object to use.
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.

Returns
-------
ax: matplotlib.Axes
The plot axes.
"""

if ax is None:
fig, ax = plt.subplots(figsize=size)

x = -0.4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is x?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the x coordinate for the rectangle vertex, changed now

w = 0.8 # rectangle width
rectangles = []
cmap = colormaps['Greys']
colors = np.flip(cmap(np.linspace(0.3, 0.8, (ensemble.n_vectors-1))), axis=0)
colors = np.insert(colors, 0, [1. , 0. , 0. , 1. ], axis=0)

for i, par_values in enumerate(ensemble.x_vectors):
h = np.max(par_values) - np.min(par_values) # rectangle height
rectangles.append(
Rectangle((x, np.min(par_values)), w, h))
x += w + 0.2
ax.add_collection(PatchCollection(rectangles, facecolors=[1, 1, 1, 1], edgecolors='dimgrey'))

for i, v in enumerate(ensemble.x_vectors):
ax.scatter(x=[i]*ensemble.n_vectors, y=v, s=40, color=colors, alpha=0.6)
# plot the best parameter values
ax.scatter(np.arange(ensemble.n_x), ensemble.x_vectors[:, 0], s=40,
color=[1. , 0. , 0. , 1. ])

ax.plot(np.arange(ensemble.n_x), ensemble.lower_bound, '--', color='grey')
ax.plot(np.arange(ensemble.n_x), ensemble.upper_bound, '--', color='grey')
ax.set_ylim(np.min(ensemble.lower_bound) * 1.1, np.max(ensemble.upper_bound) * 1.1)
plt.xticks(np.arange(ensemble.n_x), ensemble.x_names, rotation='vertical')
plt.tight_layout()

return ax
15 changes: 15 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,21 @@ def test_ensemble_identifiability():
# test plotting from a collection object
visualize.ensemble_identifiability(my_ensemble)

@close_fig
def test_ensemble_parameters_plot():
# creates a test problem
problem = create_problem(n_parameters=100)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

such a big one necessary for the test? 🙈


my_ensemble = [
(1 + np.cos(ix) ** 2) * np.random.rand(500) - 1.0 + np.sin(ix)
for ix in range(100)
]
my_ensemble = ensemble.Ensemble(
np.array(my_ensemble), lower_bound=problem.lb, upper_bound=problem.ub
)

visualize.ensemble_parameters_plot(my_ensemble)


@close_fig
def test_profiles():
Expand Down
Loading