@@ -150,6 +150,7 @@ def plot_carpet(
150150 size = (900 , 1200 ),
151151 sort_rows = "ward" ,
152152 drop_trs = 0 ,
153+ legend = True ,
153154):
154155 """
155156 Plot an image representation of voxel intensities across time.
@@ -195,18 +196,23 @@ def plot_carpet(
195196 "whole brain (voxels)" : list (range (data .shape [0 ]))
196197 }
197198
199+ nsegments = len (segments )
200+ if nsegments == 1 :
201+ legend = False
202+
198203 if cmap is None :
199204 colors = cm .get_cmap ("tab10" ).colors
200205 elif cmap == "paired" :
201206 colors = list (cm .get_cmap ("Paired" ).colors )
202207 colors [0 ], colors [1 ] = colors [1 ], colors [0 ]
203208 colors [2 ], colors [7 ] = colors [7 ], colors [2 ]
204209
205- vminmax = (None , None )
206210 if detrend :
207211 from nilearn .signal import clean
208212 data = clean (data .T , t_r = tr , filter = False ).T
209- vminmax = (np .percentile (data , 2 ), np .percentile (data , 98 ))
213+
214+ # We want all subplots to have the same dynamic range
215+ vminmax = (np .percentile (data , 2 ), np .percentile (data , 98 ))
210216
211217 # Decimate number of time-series before clustering
212218 n_dec = int ((1.8 * data .shape [0 ]) // size [0 ])
@@ -245,8 +251,6 @@ def plot_carpet(
245251 t_dec = max (int ((1.8 * n_trs ) // size [1 ]), 1 )
246252 data = data [:, drop_trs ::t_dec ]
247253
248- nsegments = len (segments )
249-
250254 # Define nested GridSpec
251255 gs = mgs .GridSpecFromSubplotSpec (
252256 nsegments ,
@@ -285,14 +289,9 @@ def plot_carpet(
285289 xticks = np .linspace (0 , data .shape [- 1 ], endpoint = True , num = 7 )
286290 ax .set_xticks (xticks )
287291 ax .set_yticks ([])
288- ax .set_ylabel (label )
289292 ax .grid (False )
290293
291- if i < (nsegments - 1 ):
292- ax .spines ["bottom" ].set_color ("none" )
293- ax .spines ["bottom" ].set_visible (False )
294- ax .set_xticklabels ([])
295- else :
294+ if i == (nsegments - 1 ):
296295 xlabel = "time-points (index)"
297296 xticklabels = (xticks * n_trs / data .shape [- 1 ]).astype ("uint32" ) + drop_trs
298297 if tr is not None :
@@ -307,10 +306,51 @@ def plot_carpet(
307306 ax .spines ["bottom" ].set_position (("outward" , 5 ))
308307 ax .spines ["bottom" ].set_color ("k" )
309308 ax .spines ["bottom" ].set_linewidth (.8 )
309+ else :
310+ ax .set_xticklabels ([])
311+ ax .set_xticks ([])
312+ ax .spines ["bottom" ].set_color ("none" )
313+ ax .spines ["bottom" ].set_visible (False )
310314
311315 if title and i == 0 :
312316 ax .set_title (title )
313317
318+ if nsegments == 1 :
319+ ax .set_ylabel (label )
320+
321+ if legend :
322+ from matplotlib .patches import Patch
323+ from mpl_toolkits .axes_grid1 .inset_locator import inset_axes
324+
325+ axlegend = inset_axes (
326+ ax ,
327+ width = "100%" ,
328+ height = 0.01 ,
329+ loc = 'lower center' ,
330+ borderpad = - 4.1 ,
331+ )
332+ axlegend .grid (False )
333+ axlegend .set_xticks ([])
334+ axlegend .set_yticks ([])
335+ axlegend .patch .set_alpha (0.0 )
336+ for loc in ("top" , "bottom" , "left" , "right" ):
337+ axlegend .spines [loc ].set_color ("none" )
338+ axlegend .spines [loc ].set_visible (False )
339+
340+ axlegend .legend (
341+ handles = [
342+ Patch (color = colors [i ], label = l )
343+ for i , l in enumerate (segments .keys ())
344+ ],
345+ loc = "upper center" ,
346+ bbox_to_anchor = (0.5 , 0 ),
347+ shadow = False ,
348+ fancybox = False ,
349+ ncol = min (len (segments .keys ()), 5 ),
350+ frameon = False ,
351+ prop = {'size' : 8 }
352+ )
353+
314354 if output_file is not None :
315355 figure = plt .gcf ()
316356 figure .savefig (output_file , bbox_inches = "tight" )
0 commit comments