Skip to content

Commit 1942c83

Browse files
authored
Merge pull request #687 from nipreps/fix/carpetplot-minimize-processing
ENH: Refactor carpetplot reasigning responsibilities
2 parents b58e383 + 5ff3786 commit 1942c83

File tree

5 files changed

+418
-316
lines changed

5 files changed

+418
-316
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ jobs:
225225
mkdir -p $PWD/artifacts $PWD/summaries
226226
docker run -u $( id -u ):$( id -g ) -it --rm -w /src/niworkflows \
227227
-e COVERAGE_FILE=/tmp/summaries/.pytest.coverage \
228+
-e SAVE_CIRCLE_ARTIFACTS="/tmp/artifacts/" \
228229
-e TEST_DATA_HOME=/data -v /tmp/data:/data \
229230
-e FS_LICENSE=/etc/fslicense.txt \
230231
-v /tmp/fslicense/license.txt:/etc/fslicense.txt:ro \

niworkflows/interfaces/plotting.py

Lines changed: 116 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323
"""Visualization tools."""
24+
from collections import defaultdict
2425
import numpy as np
26+
import nibabel as nb
2527

2628
from nipype.utils.filemanip import fname_presuffix
2729
from nipype.interfaces.base import (
@@ -37,14 +39,14 @@
3739

3840
class _FMRISummaryInputSpec(BaseInterfaceInputSpec):
3941
in_func = File(exists=True, mandatory=True, desc="")
40-
in_mask = File(exists=True, desc="")
42+
in_spikes_bg = File(exists=True, desc="")
43+
fd = File(exists=True, desc="")
44+
dvars = File(exists=True, desc="")
45+
outliers = File(exists=True, desc="")
4146
in_segm = File(exists=True, desc="")
42-
in_spikes_bg = File(exists=True, mandatory=True, desc="")
43-
fd = File(exists=True, mandatory=True, desc="")
44-
fd_thres = traits.Float(0.2, usedefault=True, desc="")
45-
dvars = File(exists=True, mandatory=True, desc="")
46-
outliers = File(exists=True, mandatory=True, desc="")
4747
tr = traits.Either(None, traits.Float, usedefault=True, desc="the TR")
48+
fd_thres = traits.Float(0.2, usedefault=True, desc="")
49+
drop_trs = traits.Int(0, usedefault=True, desc="dummy scans")
4850

4951

5052
class _FMRISummaryOutputSpec(TraitedSpec):
@@ -67,28 +69,44 @@ def _run_interface(self, runtime):
6769
newpath=runtime.cwd,
6870
)
6971

70-
dataframe = pd.DataFrame(
71-
{
72-
"outliers": np.loadtxt(self.inputs.outliers, usecols=[0]).tolist(),
73-
# Pick non-standardize dvars (col 1)
74-
# First timepoint is NaN (difference)
75-
"DVARS": [np.nan]
76-
+ np.loadtxt(self.inputs.dvars, skiprows=1, usecols=[1]).tolist(),
77-
# First timepoint is zero (reference volume)
78-
"FD": [0.0]
79-
+ np.loadtxt(self.inputs.fd, skiprows=1, usecols=[0]).tolist(),
80-
}
72+
dataframe = pd.DataFrame({
73+
"outliers": np.loadtxt(self.inputs.outliers, usecols=[0]).tolist(),
74+
# Pick non-standardize dvars (col 1)
75+
# First timepoint is NaN (difference)
76+
"DVARS": [np.nan]
77+
+ np.loadtxt(self.inputs.dvars, skiprows=1, usecols=[1]).tolist(),
78+
# First timepoint is zero (reference volume)
79+
"FD": [0.0]
80+
+ np.loadtxt(self.inputs.fd, skiprows=1, usecols=[0]).tolist(),
81+
}) if (
82+
isdefined(self.inputs.outliers)
83+
and isdefined(self.inputs.dvars)
84+
and isdefined(self.inputs.fd)
85+
) else None
86+
87+
input_data = nb.load(self.inputs.in_func)
88+
seg_file = self.inputs.in_segm if isdefined(self.inputs.in_segm) else None
89+
dataset, segments = (
90+
_cifti_timeseries(input_data)
91+
if isinstance(input_data, nb.Cifti2Image) else
92+
_nifti_timeseries(input_data, seg_file)
8193
)
8294

8395
fig = fMRIPlot(
84-
self.inputs.in_func,
85-
mask_file=self.inputs.in_mask if isdefined(self.inputs.in_mask) else None,
86-
seg_file=self.inputs.in_segm if isdefined(self.inputs.in_segm) else None,
87-
spikes_files=[self.inputs.in_spikes_bg],
88-
tr=self.inputs.tr,
89-
data=dataframe[["outliers", "DVARS", "FD"]],
96+
dataset,
97+
segments=segments,
98+
spikes_files=(
99+
[self.inputs.in_spikes_bg]
100+
if isdefined(self.inputs.in_spikes_bg) else None
101+
),
102+
tr=(
103+
self.inputs.tr if isdefined(self.inputs.tr) else
104+
_get_tr(self.inputs.in_func)
105+
),
106+
confounds=dataframe,
90107
units={"outliers": "%", "FD": "mm"},
91108
vlines={"FD": [self.inputs.fd_thres]},
109+
nskip=self.inputs.drop_trs,
92110
).plot()
93111
fig.savefig(self._results["out_file"], bbox_inches="tight")
94112
return runtime
@@ -203,3 +221,78 @@ def _run_interface(self, runtime):
203221
reference=self.inputs.reference_column,
204222
)
205223
return runtime
224+
225+
226+
def _cifti_timeseries(dataset):
227+
"""Extract timeseries from CIFTI2 dataset."""
228+
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
229+
230+
if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
231+
raise ValueError("Not a dense timeseries")
232+
233+
matrix = dataset.header.matrix
234+
seg = defaultdict(list)
235+
for bm in matrix.get_index_map(1).brain_models:
236+
label = bm.brain_structure.replace("CIFTI_STRUCTURE_", "").replace("_", " ").title()
237+
if "CORTEX" not in bm.brain_structure and "CEREBELLUM" not in bm.brain_structure:
238+
label = "Other"
239+
240+
seg[label] += list(range(
241+
bm.index_offset, bm.index_offset + bm.index_count
242+
))
243+
244+
return dataset.get_fdata(dtype="float32").T, seg
245+
246+
247+
def _nifti_timeseries(
248+
dataset,
249+
segmentation=None,
250+
lut=None,
251+
labels=("CSF", "WM", "Cerebellum", "Cortex")
252+
):
253+
"""Extract timeseries from NIfTI1/2 datasets."""
254+
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
255+
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
256+
257+
if segmentation is None:
258+
return data, None
259+
260+
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
261+
# Map segmentation
262+
if lut is None:
263+
lut = np.zeros((256,), dtype="int")
264+
lut[1:11] = 4
265+
lut[255] = 3
266+
lut[30:99] = 2
267+
lut[100:201] = 1
268+
# Apply lookup table
269+
seg = lut[np.asanyarray(segmentation.dataobj, dtype=int)].reshape(-1)
270+
fgmask = seg > 0
271+
seg = seg[fgmask]
272+
seg_dict = {}
273+
for i in np.unique(seg):
274+
seg_dict[labels[i - 1]] = np.argwhere(seg == i).squeeze()
275+
276+
return data[fgmask], seg_dict
277+
278+
279+
def _get_tr(img):
280+
"""
281+
Attempt to extract repetition time from NIfTI/CIFTI header
282+
283+
Examples
284+
--------
285+
>>> _get_tr(nb.load(Path(test_data) /
286+
... 'sub-ds205s03_task-functionallocalizer_run-01_bold_volreg.nii.gz'))
287+
2.2
288+
>>> _get_tr(nb.load(Path(test_data) /
289+
... 'sub-01_task-mixedgamblestask_run-02_space-fsLR_den-91k_bold.dtseries.nii'))
290+
2.0
291+
292+
"""
293+
294+
try:
295+
return img.header.matrix.get_index_map(0).series_step
296+
except AttributeError:
297+
return img.header.get_zooms()[-1]
298+
raise RuntimeError("Could not extract TR - unknown data structure type")
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright 2022 The NiPreps Developers <nipreps@gmail.com>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
"""Tests plotting interfaces."""
24+
import os
25+
import nibabel as nb
26+
from niworkflows import viz
27+
from niworkflows.interfaces.plotting import (
28+
_get_tr,
29+
_cifti_timeseries,
30+
_nifti_timeseries,
31+
)
32+
from niworkflows.tests.conftest import datadir
33+
34+
35+
def test_cifti_carpetplot():
36+
"""Exercise extraction of timeseries from CIFTI2."""
37+
save_artifacts = os.getenv("SAVE_CIRCLE_ARTIFACTS", False)
38+
39+
cifti_file = os.path.join(
40+
datadir,
41+
"sub-01_task-mixedgamblestask_run-02_space-fsLR_den-91k_bold.dtseries.nii",
42+
)
43+
data, segments = _cifti_timeseries(cifti_file)
44+
viz.plot_carpet(
45+
data,
46+
segments,
47+
tr=_get_tr(nb.load(cifti_file)),
48+
output_file=(
49+
os.path.join(
50+
save_artifacts, "carpetplot_cifti.svg"
51+
) if save_artifacts else None
52+
),
53+
drop_trs=0,
54+
cmap="paired",
55+
)
56+
57+
58+
def test_nifti_carpetplot():
59+
"""Exercise extraction of timeseries from CIFTI2."""
60+
save_artifacts = os.getenv("SAVE_CIRCLE_ARTIFACTS", False)
61+
62+
nifti_file = os.path.join(
63+
datadir,
64+
"sub-ds205s03_task-functionallocalizer_run-01_bold_volreg.nii.gz",
65+
)
66+
seg_file = os.path.join(
67+
datadir,
68+
"sub-ds205s03_task-functionallocalizer_run-01_bold_parc.nii.gz",
69+
)
70+
data, segments = _nifti_timeseries(nifti_file, seg_file)
71+
viz.plot_carpet(
72+
data,
73+
segments,
74+
tr=_get_tr(nb.load(nifti_file)),
75+
output_file=(
76+
os.path.join(
77+
save_artifacts, "carpetplot_nifti.svg"
78+
) if save_artifacts else None
79+
),
80+
drop_trs=0,
81+
)

0 commit comments

Comments
 (0)