Skip to content

Commit e6b074b

Browse files
authored
Merge pull request #690 from nipreps/enh/carpetplot-labels-readability
ENH: Add a legend to carpet plots with more than one segment
2 parents c0cc020 + 6d4240e commit e6b074b

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

niworkflows/tests/test_viz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_carpetplot(tr, sorting):
6060
drop_trs=15,
6161
)
6262

63-
labels = ("Cortical GM", "Deep GM", "Cerebellar GM", "WM + brainstem", "CSF")
63+
labels = ("Ctx GM", "Subctx GM", "WM+CSF", "Cereb.", "Edge")
6464
sizes = (200, 100, 50, 100, 50)
6565
total_size = np.sum(sizes)
6666
data = np.zeros((total_size, 300))

niworkflows/viz/plots.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def plot_carpet(
150150
size=(900, 1200),
151151
sort_rows="ward",
152152
drop_trs=0,
153+
legend=True,
153154
):
154155
"""
155156
Plot an image representation of voxel intensities across time.
@@ -195,18 +196,23 @@ def plot_carpet(
195196
"whole brain (voxels)": list(range(data.shape[0]))
196197
}
197198

199+
nsegments = len(segments)
200+
if nsegments == 1:
201+
legend = False
202+
198203
if cmap is None:
199204
colors = cm.get_cmap("tab10").colors
200205
elif cmap == "paired":
201206
colors = list(cm.get_cmap("Paired").colors)
202207
colors[0], colors[1] = colors[1], colors[0]
203208
colors[2], colors[7] = colors[7], colors[2]
204209

205-
vminmax = (None, None)
206210
if detrend:
207211
from nilearn.signal import clean
208212
data = clean(data.T, t_r=tr, filter=False).T
209-
vminmax = (np.percentile(data, 2), np.percentile(data, 98))
213+
214+
# We want all subplots to have the same dynamic range
215+
vminmax = (np.percentile(data, 2), np.percentile(data, 98))
210216

211217
# Decimate number of time-series before clustering
212218
n_dec = int((1.8 * data.shape[0]) // size[0])
@@ -245,8 +251,6 @@ def plot_carpet(
245251
t_dec = max(int((1.8 * n_trs) // size[1]), 1)
246252
data = data[:, drop_trs::t_dec]
247253

248-
nsegments = len(segments)
249-
250254
# Define nested GridSpec
251255
gs = mgs.GridSpecFromSubplotSpec(
252256
nsegments,
@@ -285,14 +289,9 @@ def plot_carpet(
285289
xticks = np.linspace(0, data.shape[-1], endpoint=True, num=7)
286290
ax.set_xticks(xticks)
287291
ax.set_yticks([])
288-
ax.set_ylabel(label)
289292
ax.grid(False)
290293

291-
if i < (nsegments - 1):
292-
ax.spines["bottom"].set_color("none")
293-
ax.spines["bottom"].set_visible(False)
294-
ax.set_xticklabels([])
295-
else:
294+
if i == (nsegments - 1):
296295
xlabel = "time-points (index)"
297296
xticklabels = (xticks * n_trs / data.shape[-1]).astype("uint32") + drop_trs
298297
if tr is not None:
@@ -307,10 +306,51 @@ def plot_carpet(
307306
ax.spines["bottom"].set_position(("outward", 5))
308307
ax.spines["bottom"].set_color("k")
309308
ax.spines["bottom"].set_linewidth(.8)
309+
else:
310+
ax.set_xticklabels([])
311+
ax.set_xticks([])
312+
ax.spines["bottom"].set_color("none")
313+
ax.spines["bottom"].set_visible(False)
310314

311315
if title and i == 0:
312316
ax.set_title(title)
313317

318+
if nsegments == 1:
319+
ax.set_ylabel(label)
320+
321+
if legend:
322+
from matplotlib.patches import Patch
323+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
324+
325+
axlegend = inset_axes(
326+
ax,
327+
width="100%",
328+
height=0.01,
329+
loc='lower center',
330+
borderpad=-4.1,
331+
)
332+
axlegend.grid(False)
333+
axlegend.set_xticks([])
334+
axlegend.set_yticks([])
335+
axlegend.patch.set_alpha(0.0)
336+
for loc in ("top", "bottom", "left", "right"):
337+
axlegend.spines[loc].set_color("none")
338+
axlegend.spines[loc].set_visible(False)
339+
340+
axlegend.legend(
341+
handles=[
342+
Patch(color=colors[i], label=l)
343+
for i, l in enumerate(segments.keys())
344+
],
345+
loc="upper center",
346+
bbox_to_anchor=(0.5, 0),
347+
shadow=False,
348+
fancybox=False,
349+
ncol=min(len(segments.keys()), 5),
350+
frameon=False,
351+
prop={'size': 8}
352+
)
353+
314354
if output_file is not None:
315355
figure = plt.gcf()
316356
figure.savefig(output_file, bbox_inches="tight")

0 commit comments

Comments
 (0)