Skip to content

Commit 8c0b6a6

Browse files
authored
Merge pull request #688 from nipreps/enh/carpetplot-improvements
ENH: Miscellaneous improvements to carpetplot
2 parents 1942c83 + 00755e0 commit 8c0b6a6

File tree

5 files changed

+168
-135
lines changed

5 files changed

+168
-135
lines changed

niworkflows/interfaces/plotting.py

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323
"""Visualization tools."""
24-
from collections import defaultdict
2524
import numpy as np
2625
import nibabel as nb
2726

@@ -34,7 +33,12 @@
3433
traits,
3534
isdefined,
3635
)
37-
from ..viz.plots import fMRIPlot, compcor_variance_plot, confounds_correlation_plot
36+
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
37+
from niworkflows.viz.plots import (
38+
fMRIPlot,
39+
compcor_variance_plot,
40+
confounds_correlation_plot,
41+
)
3842

3943

4044
class _FMRISummaryInputSpec(BaseInterfaceInputSpec):
@@ -101,7 +105,7 @@ def _run_interface(self, runtime):
101105
),
102106
tr=(
103107
self.inputs.tr if isdefined(self.inputs.tr) else
104-
_get_tr(self.inputs.in_func)
108+
_get_tr(input_data)
105109
),
106110
confounds=dataframe,
107111
units={"outliers": "%", "FD": "mm"},
@@ -223,59 +227,6 @@ def _run_interface(self, runtime):
223227
return runtime
224228

225229

226-
def _cifti_timeseries(dataset):
227-
"""Extract timeseries from CIFTI2 dataset."""
228-
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
229-
230-
if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
231-
raise ValueError("Not a dense timeseries")
232-
233-
matrix = dataset.header.matrix
234-
seg = defaultdict(list)
235-
for bm in matrix.get_index_map(1).brain_models:
236-
label = bm.brain_structure.replace("CIFTI_STRUCTURE_", "").replace("_", " ").title()
237-
if "CORTEX" not in bm.brain_structure and "CEREBELLUM" not in bm.brain_structure:
238-
label = "Other"
239-
240-
seg[label] += list(range(
241-
bm.index_offset, bm.index_offset + bm.index_count
242-
))
243-
244-
return dataset.get_fdata(dtype="float32").T, seg
245-
246-
247-
def _nifti_timeseries(
248-
dataset,
249-
segmentation=None,
250-
lut=None,
251-
labels=("CSF", "WM", "Cerebellum", "Cortex")
252-
):
253-
"""Extract timeseries from NIfTI1/2 datasets."""
254-
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
255-
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
256-
257-
if segmentation is None:
258-
return data, None
259-
260-
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
261-
# Map segmentation
262-
if lut is None:
263-
lut = np.zeros((256,), dtype="int")
264-
lut[1:11] = 4
265-
lut[255] = 3
266-
lut[30:99] = 2
267-
lut[100:201] = 1
268-
# Apply lookup table
269-
seg = lut[np.asanyarray(segmentation.dataobj, dtype=int)].reshape(-1)
270-
fgmask = seg > 0
271-
seg = seg[fgmask]
272-
seg_dict = {}
273-
for i in np.unique(seg):
274-
seg_dict[labels[i - 1]] = np.argwhere(seg == i).squeeze()
275-
276-
return data[fgmask], seg_dict
277-
278-
279230
def _get_tr(img):
280231
"""
281232
Attempt to extract repetition time from NIfTI/CIFTI header

niworkflows/interfaces/tests/test_plotting.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@
2424
import os
2525
import nibabel as nb
2626
from niworkflows import viz
27-
from niworkflows.interfaces.plotting import (
28-
_get_tr,
29-
_cifti_timeseries,
30-
_nifti_timeseries,
31-
)
27+
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
28+
from niworkflows.interfaces.plotting import _get_tr
3229
from niworkflows.tests.conftest import datadir
3330

3431

niworkflows/tests/test_viz.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@
3333
from .generate_data import _create_dtseries_cifti
3434
from .. import viz
3535
from niworkflows.viz.plots import fMRIPlot
36-
from niworkflows.interfaces.plotting import (
37-
_cifti_timeseries,
38-
_nifti_timeseries,
39-
_get_tr,
40-
)
36+
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
37+
from niworkflows.interfaces.plotting import _get_tr
4138

4239

4340
@pytest.mark.parametrize("tr", (None, 0.7))

niworkflows/utils/timeseries.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright 2022 The NiPreps Developers <nipreps@gmail.com>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
"""Extracting signals from NIfTI and CIFTI2 files."""
24+
import numpy as np
25+
import nibabel as nb
26+
27+
28+
def _cifti_timeseries(dataset):
29+
"""Extract timeseries from CIFTI2 dataset."""
30+
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
31+
32+
if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
33+
raise ValueError("Not a dense timeseries")
34+
35+
matrix = dataset.header.matrix
36+
labels = {
37+
"CIFTI_STRUCTURE_CORTEX_LEFT": "CtxL",
38+
"CIFTI_STRUCTURE_CORTEX_RIGHT": "CtxR",
39+
"CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL",
40+
"CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR",
41+
}
42+
seg = {label: [] for label in list(labels.values()) + ["Other"]}
43+
for bm in matrix.get_index_map(1).brain_models:
44+
label = (
45+
"Other" if bm.brain_structure not in labels else
46+
labels[bm.brain_structure]
47+
)
48+
seg[label] += list(range(
49+
bm.index_offset, bm.index_offset + bm.index_count
50+
))
51+
52+
return dataset.get_fdata(dtype="float32").T, seg
53+
54+
55+
def _nifti_timeseries(
56+
dataset,
57+
segmentation=None,
58+
labels=("Ctx GM", "dGM", "WM+CSF", "Cb"),
59+
remap_rois=True,
60+
lut=None,
61+
):
62+
"""Extract timeseries from NIfTI1/2 datasets."""
63+
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
64+
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
65+
66+
if segmentation is None:
67+
return data, None
68+
69+
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
70+
# Map segmentation
71+
if remap_rois or lut is not None:
72+
if lut is None:
73+
lut = np.zeros((256,), dtype="int")
74+
lut[100:201] = 1 # Ctx GM
75+
lut[30:99] = 2 # dGM
76+
lut[1:11] = 3 # WM+CSF
77+
lut[255] = 4 # Cerebellum
78+
# Apply lookup table
79+
segmentation = lut[np.asanyarray(segmentation.dataobj, dtype=int)].reshape(-1)
80+
81+
fgmask = segmentation > 0
82+
segmentation = segmentation[fgmask]
83+
seg_dict = {}
84+
for i in np.unique(segmentation):
85+
seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze()
86+
87+
return data[fgmask], seg_dict

niworkflows/viz/plots.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import matplotlib.pyplot as plt
3030
from matplotlib import gridspec as mgs
3131
import matplotlib.cm as cm
32-
from matplotlib.colors import ListedColormap, Normalize
32+
from matplotlib.colors import Normalize
3333
from matplotlib.colorbar import ColorbarBase
3434

3535
DINA4_LANDSCAPE = (11.69, 8.27)
@@ -192,13 +192,15 @@ def plot_carpet(
192192
"""
193193
if segments is None:
194194
segments = {
195-
"brain": list(range(data.shape[0]))
195+
"whole brain (voxels)": list(range(data.shape[0]))
196196
}
197197

198198
if cmap is None:
199-
cmap = ListedColormap(cm.get_cmap("tab10").colors[:len(segments)])
199+
colors = cm.get_cmap("tab10").colors
200200
elif cmap == "paired":
201-
cmap = ListedColormap([cm.get_cmap("Paired").colors[i] for i in (1, 0, 7, 3)])
201+
colors = list(cm.get_cmap("Paired").colors)
202+
colors[0], colors[1] = colors[1], colors[0]
203+
colors[2], colors[7] = colors[7], colors[2]
202204

203205
vminmax = (None, None)
204206
if detrend:
@@ -239,76 +241,75 @@ def plot_carpet(
239241
# Length before decimation
240242
n_trs = data.shape[-1] - drop_trs
241243

242-
# Define nested GridSpec
243-
gs = mgs.GridSpecFromSubplotSpec(
244-
1, 2, subplot_spec=subplot, width_ratios=(1, 100), wspace=0.0,
245-
)
246-
247-
# Segmentation colorbar
248-
colors = np.hstack([[i + 1] * len(v) for i, v in enumerate(segments.values())])
249-
ax0 = plt.subplot(gs[0])
250-
ax0.set_yticks([])
251-
ax0.set_xticks([])
252-
ax0.imshow(
253-
colors[:, np.newaxis],
254-
interpolation="none",
255-
aspect="auto",
256-
cmap=cmap
257-
)
258-
259-
ax0.grid(False)
260-
ax0.spines["left"].set_visible(False)
261-
ax0.spines["bottom"].set_color("none")
262-
ax0.spines["bottom"].set_visible(False)
263-
264244
# Calculate time decimation factor
265245
t_dec = max(int((1.8 * n_trs) // size[1]), 1)
266-
data = data[np.hstack(list(segments.values())), drop_trs::t_dec]
267-
268-
# Carpet plot
269-
ax1 = plt.subplot(gs[1])
270-
ax1.imshow(
271-
data,
272-
interpolation="nearest",
273-
aspect="auto",
274-
cmap="gray",
275-
vmin=vminmax[0],
276-
vmax=vminmax[1],
277-
)
246+
data = data[:, drop_trs::t_dec]
278247

279-
ax1.grid(False)
280-
ax1.set_yticks([])
281-
ax1.set_yticklabels([])
282-
283-
xticks = np.linspace(0, data.shape[-1], endpoint=True, num=7)
284-
xlabel = "time-points (index)"
285-
xticklabels = (xticks * n_trs / data.shape[-1]).astype("uint32") + drop_trs
286-
if tr is not None:
287-
xlabel = "time (mm:ss)"
288-
xticklabels = [
289-
f"{int(t // 60):02d}:{(t % 60).round(0).astype(int):02d}"
290-
for t in (tr * xticklabels)
291-
]
248+
nsegments = len(segments)
249+
250+
# Define nested GridSpec
251+
gs = mgs.GridSpecFromSubplotSpec(
252+
nsegments,
253+
1,
254+
subplot_spec=subplot,
255+
hspace=0.05,
256+
height_ratios=[len(v) for v in segments.values()]
257+
)
292258

293-
ax1.set_xticks(xticks)
294-
ax1.set_xlabel(xlabel)
295-
ax1.set_xticklabels(xticklabels)
259+
for i, (label, indices) in enumerate(segments.items()):
260+
# Carpet plot
261+
ax = plt.subplot(gs[i])
262+
263+
ax.imshow(
264+
data[indices, :],
265+
interpolation="nearest",
266+
aspect="auto",
267+
cmap="gray",
268+
vmin=vminmax[0],
269+
vmax=vminmax[1],
270+
)
296271

297-
# Remove and redefine spines
298-
for side in ["top", "right"]:
299272
# Toggle the spine objects
300-
ax0.spines[side].set_color("none")
301-
ax0.spines[side].set_visible(False)
302-
ax1.spines[side].set_color("none")
303-
ax1.spines[side].set_visible(False)
304-
305-
ax1.yaxis.set_ticks_position("left")
306-
ax1.xaxis.set_ticks_position("bottom")
307-
ax1.spines["bottom"].set_visible(False)
308-
ax1.spines["left"].set_color("none")
309-
ax1.spines["left"].set_visible(False)
310-
if title:
311-
ax1.set_title(title)
273+
ax.spines["top"].set_color("none")
274+
ax.spines["top"].set_visible(False)
275+
ax.spines["right"].set_color("none")
276+
ax.spines["right"].set_visible(False)
277+
278+
# Make colored left axis
279+
ax.spines["left"].set_linewidth(3)
280+
ax.spines["left"].set_color(colors[i])
281+
ax.spines["left"].set_capstyle("butt")
282+
ax.spines["left"].set_position(("outward", 2))
283+
284+
# Make all subplots have same xticks
285+
xticks = np.linspace(0, data.shape[-1], endpoint=True, num=7)
286+
ax.set_xticks(xticks)
287+
ax.set_yticks([])
288+
ax.set_ylabel(label)
289+
ax.grid(False)
290+
291+
if i < (nsegments - 1):
292+
ax.spines["bottom"].set_color("none")
293+
ax.spines["bottom"].set_visible(False)
294+
ax.set_xticklabels([])
295+
else:
296+
xlabel = "time-points (index)"
297+
xticklabels = (xticks * n_trs / data.shape[-1]).astype("uint32") + drop_trs
298+
if tr is not None:
299+
xlabel = "time (mm:ss)"
300+
xticklabels = [
301+
f"{int(t // 60):02d}:{(t % 60).round(0).astype(int):02d}"
302+
for t in (tr * xticklabels)
303+
]
304+
305+
ax.set_xlabel(xlabel)
306+
ax.set_xticklabels(xticklabels)
307+
ax.spines["bottom"].set_position(("outward", 5))
308+
ax.spines["bottom"].set_color("k")
309+
ax.spines["bottom"].set_linewidth(.8)
310+
311+
if title and i == 0:
312+
ax.set_title(title)
312313

313314
if output_file is not None:
314315
figure = plt.gcf()
@@ -317,7 +318,7 @@ def plot_carpet(
317318
figure = None
318319
return output_file
319320

320-
return (ax0, ax1), gs
321+
return gs
321322

322323

323324
def spikesplot(

0 commit comments

Comments
 (0)