2929import matplotlib .pyplot as plt
3030from matplotlib import gridspec as mgs
3131import matplotlib .cm as cm
32- from matplotlib .colors import ListedColormap , Normalize
32+ from matplotlib .colors import Normalize
3333from matplotlib .colorbar import ColorbarBase
3434
3535DINA4_LANDSCAPE = (11.69 , 8.27 )
@@ -192,13 +192,15 @@ def plot_carpet(
192192 """
193193 if segments is None :
194194 segments = {
195- "brain" : list (range (data .shape [0 ]))
195+ "whole brain (voxels) " : list (range (data .shape [0 ]))
196196 }
197197
198198 if cmap is None :
199- cmap = ListedColormap ( cm .get_cmap ("tab10" ).colors [: len ( segments )])
199+ colors = cm .get_cmap ("tab10" ).colors
200200 elif cmap == "paired" :
201- cmap = ListedColormap ([cm .get_cmap ("Paired" ).colors [i ] for i in (1 , 0 , 7 , 3 )])
201+ colors = list (cm .get_cmap ("Paired" ).colors )
202+ colors [0 ], colors [1 ] = colors [1 ], colors [0 ]
203+ colors [2 ], colors [7 ] = colors [7 ], colors [2 ]
202204
203205 vminmax = (None , None )
204206 if detrend :
@@ -239,76 +241,75 @@ def plot_carpet(
239241 # Length before decimation
240242 n_trs = data .shape [- 1 ] - drop_trs
241243
242- # Define nested GridSpec
243- gs = mgs .GridSpecFromSubplotSpec (
244- 1 , 2 , subplot_spec = subplot , width_ratios = (1 , 100 ), wspace = 0.0 ,
245- )
246-
247- # Segmentation colorbar
248- colors = np .hstack ([[i + 1 ] * len (v ) for i , v in enumerate (segments .values ())])
249- ax0 = plt .subplot (gs [0 ])
250- ax0 .set_yticks ([])
251- ax0 .set_xticks ([])
252- ax0 .imshow (
253- colors [:, np .newaxis ],
254- interpolation = "none" ,
255- aspect = "auto" ,
256- cmap = cmap
257- )
258-
259- ax0 .grid (False )
260- ax0 .spines ["left" ].set_visible (False )
261- ax0 .spines ["bottom" ].set_color ("none" )
262- ax0 .spines ["bottom" ].set_visible (False )
263-
264244 # Calculate time decimation factor
265245 t_dec = max (int ((1.8 * n_trs ) // size [1 ]), 1 )
266- data = data [np .hstack (list (segments .values ())), drop_trs ::t_dec ]
267-
268- # Carpet plot
269- ax1 = plt .subplot (gs [1 ])
270- ax1 .imshow (
271- data ,
272- interpolation = "nearest" ,
273- aspect = "auto" ,
274- cmap = "gray" ,
275- vmin = vminmax [0 ],
276- vmax = vminmax [1 ],
277- )
246+ data = data [:, drop_trs ::t_dec ]
278247
279- ax1 .grid (False )
280- ax1 .set_yticks ([])
281- ax1 .set_yticklabels ([])
282-
283- xticks = np .linspace (0 , data .shape [- 1 ], endpoint = True , num = 7 )
284- xlabel = "time-points (index)"
285- xticklabels = (xticks * n_trs / data .shape [- 1 ]).astype ("uint32" ) + drop_trs
286- if tr is not None :
287- xlabel = "time (mm:ss)"
288- xticklabels = [
289- f"{ int (t // 60 ):02d} :{ (t % 60 ).round (0 ).astype (int ):02d} "
290- for t in (tr * xticklabels )
291- ]
248+ nsegments = len (segments )
249+
250+ # Define nested GridSpec
251+ gs = mgs .GridSpecFromSubplotSpec (
252+ nsegments ,
253+ 1 ,
254+ subplot_spec = subplot ,
255+ hspace = 0.05 ,
256+ height_ratios = [len (v ) for v in segments .values ()]
257+ )
292258
293- ax1 .set_xticks (xticks )
294- ax1 .set_xlabel (xlabel )
295- ax1 .set_xticklabels (xticklabels )
259+ for i , (label , indices ) in enumerate (segments .items ()):
260+ # Carpet plot
261+ ax = plt .subplot (gs [i ])
262+
263+ ax .imshow (
264+ data [indices , :],
265+ interpolation = "nearest" ,
266+ aspect = "auto" ,
267+ cmap = "gray" ,
268+ vmin = vminmax [0 ],
269+ vmax = vminmax [1 ],
270+ )
296271
297- # Remove and redefine spines
298- for side in ["top" , "right" ]:
299272 # Toggle the spine objects
300- ax0 .spines [side ].set_color ("none" )
301- ax0 .spines [side ].set_visible (False )
302- ax1 .spines [side ].set_color ("none" )
303- ax1 .spines [side ].set_visible (False )
304-
305- ax1 .yaxis .set_ticks_position ("left" )
306- ax1 .xaxis .set_ticks_position ("bottom" )
307- ax1 .spines ["bottom" ].set_visible (False )
308- ax1 .spines ["left" ].set_color ("none" )
309- ax1 .spines ["left" ].set_visible (False )
310- if title :
311- ax1 .set_title (title )
273+ ax .spines ["top" ].set_color ("none" )
274+ ax .spines ["top" ].set_visible (False )
275+ ax .spines ["right" ].set_color ("none" )
276+ ax .spines ["right" ].set_visible (False )
277+
278+ # Make colored left axis
279+ ax .spines ["left" ].set_linewidth (3 )
280+ ax .spines ["left" ].set_color (colors [i ])
281+ ax .spines ["left" ].set_capstyle ("butt" )
282+ ax .spines ["left" ].set_position (("outward" , 2 ))
283+
284+ # Make all subplots have same xticks
285+ xticks = np .linspace (0 , data .shape [- 1 ], endpoint = True , num = 7 )
286+ ax .set_xticks (xticks )
287+ ax .set_yticks ([])
288+ ax .set_ylabel (label )
289+ ax .grid (False )
290+
291+ if i < (nsegments - 1 ):
292+ ax .spines ["bottom" ].set_color ("none" )
293+ ax .spines ["bottom" ].set_visible (False )
294+ ax .set_xticklabels ([])
295+ else :
296+ xlabel = "time-points (index)"
297+ xticklabels = (xticks * n_trs / data .shape [- 1 ]).astype ("uint32" ) + drop_trs
298+ if tr is not None :
299+ xlabel = "time (mm:ss)"
300+ xticklabels = [
301+ f"{ int (t // 60 ):02d} :{ (t % 60 ).round (0 ).astype (int ):02d} "
302+ for t in (tr * xticklabels )
303+ ]
304+
305+ ax .set_xlabel (xlabel )
306+ ax .set_xticklabels (xticklabels )
307+ ax .spines ["bottom" ].set_position (("outward" , 5 ))
308+ ax .spines ["bottom" ].set_color ("k" )
309+ ax .spines ["bottom" ].set_linewidth (.8 )
310+
311+ if title and i == 0 :
312+ ax .set_title (title )
312313
313314 if output_file is not None :
314315 figure = plt .gcf ()
@@ -317,7 +318,7 @@ def plot_carpet(
317318 figure = None
318319 return output_file
319320
320- return ( ax0 , ax1 ), gs
321+ return gs
321322
322323
323324def spikesplot (
0 commit comments