Skip to content

Commit d9bd78a

Browse files
committed
final touches
1 parent a67e14c commit d9bd78a

File tree

1 file changed

+52
-50
lines changed

1 file changed

+52
-50
lines changed

niworkflows/viz/plots.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(self, func_file, mask_file=None, data=None, conf_file=None, seg_fil
7272
def plot(self, figure=None):
7373
"""Main plotter"""
7474
sns.set_style("whitegrid")
75+
sns.set_context("paper", font_scale=0.8)
7576

7677
if figure is None:
7778
figure = plt.gcf()
@@ -81,7 +82,7 @@ def plot(self, figure=None):
8182
nrows = 1 + nconfounds + nspikes
8283

8384
# Create grid
84-
grid = mgs.GridSpec(nrows, 1, wspace=0.0, hspace=0.4,
85+
grid = mgs.GridSpec(nrows, 1, wspace=0.0, hspace=0.05,
8586
height_ratios=[1] * (nrows - 1) + [5])
8687

8788
grid_id = 0
@@ -218,7 +219,7 @@ def plot_carpet(img, atlaslabels, detrend=True, nskip=0, size=(950, 800),
218219
ax1.set_xticks(xticks)
219220
ax1.set_xlabel('time (s)')
220221
labels = tr * (np.array(xticks)) * t_dec
221-
ax1.set_xticklabels(['%.02f' % t for t in labels.tolist()])
222+
ax1.set_xticklabels(['%.02f' % t for t in labels.tolist()], fontsize=5)
222223

223224
# Remove and redefine spines
224225
for side in ["top", "right"]:
@@ -316,7 +317,7 @@ def spikesplot(ts_z, outer_gs=None, tr=None, zscored=True, spike_thresh=6., titl
316317
['%.02f' % t for t in (tr * np.array(xticks)).tolist()])
317318

318319
# Handle Y axis
319-
ylabel = 'slice-wise signal intensity of background'
320+
ylabel = 'slice-wise noise average on background'
320321
if zscored:
321322
ylabel += ' (z-scored)'
322323
zs_max = np.abs(ts_z).max()
@@ -347,8 +348,11 @@ def spikesplot(ts_z, outer_gs=None, tr=None, zscored=True, spike_thresh=6., titl
347348
# ts_z[:, nskip:].max() * 1.05)
348349

349350
ax.annotate(
350-
ylabel, xy=(0.01, 0.0), xytext=(0, -1), xycoords='axes fraction',
351-
textcoords='offset points', va='center', color='gray', size=8)
351+
ylabel, xy=(0.0, 0.7), xycoords='axes fraction',
352+
xytext=(0, 0), textcoords='offset points',
353+
va='center', ha='left', color='gray', size=4,
354+
bbox={'boxstyle': 'round', 'fc': 'w', 'ec': 'none', 'color': 'none',
355+
'lw': 0, 'alpha': 0.8})
352356
ax.set_yticks([])
353357
ax.set_yticklabels([])
354358

@@ -364,7 +368,7 @@ def spikesplot(ts_z, outer_gs=None, tr=None, zscored=True, spike_thresh=6., titl
364368
ax.spines[side].set_visible(False)
365369

366370
if not hide_x:
367-
ax.spines["bottom"].set_position(('outward', 20))
371+
ax.spines["bottom"].set_position(('outward', 10))
368372
ax.xaxis.set_ticks_position('bottom')
369373
else:
370374
ax.spines["bottom"].set_color('none')
@@ -416,8 +420,6 @@ def confoundplot(tseries, gs_ts, gs_dist=None, name=None,
416420

417421
ax_ts = plt.subplot(gs[1])
418422
ax_ts.grid(False)
419-
ax_ts.plot(tseries, color=color)
420-
ax_ts.set_xlim((0, ntsteps - 1))
421423

422424
# Set 10 frame markers in X axis
423425
interval = max((ntsteps // 10, ntsteps // 5, 1))
@@ -436,11 +438,14 @@ def confoundplot(tseries, gs_ts, gs_dist=None, name=None,
436438

437439
if name is not None:
438440
if units is not None:
439-
name += (' [{}]' if notr else ' [{}/s]').format(units)
441+
name += ' [%s]' % units
440442

441443
ax_ts.annotate(
442-
name, xy=(0.01, 0.0), xytext=(0, -4), xycoords='axes fraction',
443-
textcoords='offset points', va='center', color=color, size=8)
444+
name, xy=(0.0, 0.7), xytext=(0, 0), xycoords='axes fraction',
445+
textcoords='offset points', va='center', ha='left',
446+
color=color, size=8,
447+
bbox={'boxstyle': 'round', 'fc': 'w', 'ec': 'none',
448+
'color': 'none', 'lw': 0, 'alpha': 0.8})
444449

445450
for side in ["top", "right"]:
446451
ax_ts.spines[side].set_color('none')
@@ -467,57 +472,54 @@ def confoundplot(tseries, gs_ts, gs_dist=None, name=None,
467472
if ylims[1] is not None:
468473
def_ylims[1] = max([def_ylims[1], ylims[1]])
469474

475+
# Add space for plot title and mean/SD annotation
476+
def_ylims[0] -= 0.1 * (def_ylims[1] - def_ylims[0])
477+
470478
ax_ts.set_ylim(def_ylims)
471479
# yticks = sorted(def_ylims)
472480
ax_ts.set_yticks([])
473481
ax_ts.set_yticklabels([])
474482
# ax_ts.set_yticks(yticks)
475483
# ax_ts.set_yticklabels(['%.02f' % y for y in yticks])
476-
yrange = def_ylims[1] - def_ylims[0]
477484

478-
# Plot average
485+
# Annotate stats
486+
maxv = tseries[~np.isnan(tseries)].max()
487+
mean = tseries[~np.isnan(tseries)].mean()
488+
stdv = tseries[~np.isnan(tseries)].std()
489+
p95 = np.percentile(tseries[~np.isnan(tseries)], 95.0)
490+
491+
stats_label = (r'max: {max:.3f}{units} $\bullet$ mean: {mean:.3f}{units} '
492+
r'$\bullet$ $\sigma$: {sigma:.3f}').format(
493+
max=maxv, mean=mean, units=units or '', sigma=stdv)
494+
ax_ts.annotate(
495+
stats_label, xy=(0.98, 0.7), xycoords='axes fraction',
496+
xytext=(0, 0), textcoords='offset points',
497+
va='center', ha='right', color=color, size=4,
498+
bbox={'boxstyle': 'round', 'fc': 'w', 'ec': 'none', 'color': 'none',
499+
'lw': 0, 'alpha': 0.8}
500+
)
501+
502+
# Annotate percentile 95
503+
ax_ts.plot((0, ntsteps - 1), [p95] * 2, linewidth=.1, color='lightgray')
504+
ax_ts.annotate(
505+
'%.2f' % p95, xy=(0, p95), xytext=(-1, 0),
506+
textcoords='offset points', va='center', ha='right',
507+
color='lightgray', size=3)
508+
479509
if cutoff is None:
480510
cutoff = []
481511

482-
cutoff.insert(0, tseries[~np.isnan(tseries)].mean())
483-
484512
for i, thr in enumerate(cutoff):
485513
ax_ts.plot((0, ntsteps - 1), [thr] * 2,
486-
linewidth=.75,
487-
linestyle='-' if i == 0 else ':',
488-
color=color if i == 0 else 'k')
489-
490-
if i == 0:
491-
mean_label = r'$\mu$=%.3f%s' % (thr, units if units is not None else '')
492-
ax_ts.annotate(
493-
mean_label, xy=(ntsteps - 1, thr), xytext=(11, 0),
494-
textcoords='offset points', va='center', color='w', size=10,
495-
bbox=dict(boxstyle='round', fc=color, ec='none', color='none', lw=0),
496-
arrowprops=dict(
497-
arrowstyle='wedge,tail_width=0.8', lw=0, patchA=None, patchB=None,
498-
fc=color, ec='none', relpos=(0.01, 0.5)))
499-
else:
500-
y_off = [0.0, 0.0]
501-
for pth in cutoff[:i]:
502-
inc = abs(thr - pth)
503-
if inc < yrange:
504-
factor = (- (inc / yrange) + 1) ** 2
505-
if (thr - pth) < 0.0:
506-
y_off[0] -= factor * 20
507-
else:
508-
y_off[1] += factor * 20
509-
510-
offset = y_off[0] if abs(y_off[0]) > y_off[1] else y_off[1]
511-
512-
a_label = '%.2f%s' % (thr, units if units is not None else '')
513-
ax_ts.annotate(
514-
a_label, xy=(ntsteps - 1, thr), xytext=(11, offset),
515-
textcoords='offset points', va='center',
516-
color='w', size=10,
517-
bbox=dict(boxstyle='round', fc='dimgray', ec='none', color='none', lw=0),
518-
arrowprops=dict(
519-
arrowstyle='wedge,tail_width=.9', lw=0, patchA=None, patchB=None,
520-
fc='dimgray', ec='none', relpos=(.1, .5)))
514+
linewidth=.2, color='dimgray')
515+
516+
ax_ts.annotate(
517+
'%.2f' % thr, xy=(0, thr), xytext=(-1, 0),
518+
textcoords='offset points', va='center', ha='right',
519+
color='dimgray', size=3)
520+
521+
ax_ts.plot(tseries, color=color, linewidth=.8)
522+
ax_ts.set_xlim((0, ntsteps - 1))
521523

522524
if gs_dist is not None:
523525
ax_dist = plt.subplot(gs_dist)

0 commit comments

Comments
 (0)