1+ import os
2+ import numpy as np
3+ import matplotlib as mpl
4+ import matplotlib .pyplot as plt
5+
6+ import logging
7+ import tqdm
8+ import networkx as nx
9+ import nibabel as nib
10+ import pathlib
11+
12+ from IPython .display import clear_output
13+ from neurolib .utils import atlases
14+
15+
16+ class Brainplot :
17+ def __init__ (self , Cmat , data , nframes = None , dt = 0.1 , fps = 25 , labels = False , darkmode = True ):
18+ self .sc = Cmat
19+ self .n = self .sc .shape [0 ]
20+
21+ self .data = data
22+ self .darkmode = darkmode
23+
24+ self .G = nx .Graph ()
25+ self .G .add_nodes_from (range (self .n ))
26+
27+ coords = {}
28+ atlas = atlases .AutomatedAnatomicalParcellation2 ()
29+ for i , c in enumerate (atlas .coords ()):
30+ coords [i ] = [c [0 ], c [1 ]]
31+ self .position = coords
32+
33+ self .edge_threshold = 0.01
34+
35+ self .fps = fps
36+ self .dt = dt
37+
38+ nframes = nframes or int ((data .shape [1 ] * self .dt / 1000 ) * self .fps ) # 20 fps default
39+ logging .info (f"Defaulting to { nframes } frames at { self .fps } fp/s" )
40+ self .nframes = nframes
41+
42+ self .frame_interval = self .data .shape [1 ] // self .nframes
43+
44+ self .interval = int (self .frame_interval * self .dt )
45+
46+ self .draw_labels = labels
47+
48+ for t in range (self .n ):
49+ # print t
50+ for s in range (t ):
51+ # print( n, t, s)
52+ if self .sc [t , s ] > self .edge_threshold :
53+ # print( 'edge', t, s, self.sc[t,s])
54+ self .G .add_edge (t , s )
55+
56+ # node color map
57+ self .cmap = plt .get_cmap ("plasma" ) # mpl.cm.cool
58+
59+ # default style
60+
61+ self .imagealpha = 0.5
62+
63+ self .edgecolor = "k"
64+ self .edgealpha = 0.8
65+ self .edgeweight = 1.0
66+
67+ self .nodesize = 50
68+ self .nodealpha = 0.8
69+ self .vmin = 0
70+ self .vmax = 50
71+
72+ self .lw = 0.5
73+
74+ if self .darkmode :
75+ plt .style .use ("dark" )
76+ # let's choose a cyberpunk style for the dark theme
77+ self .edgecolor = "#37f522"
78+ self .edgeweight = 0.5
79+ self .edgealpha = 0.6
80+
81+ self .nodesize = 40
82+ self .nodealpha = 0.8
83+ self .vmin = 0
84+ self .vmax = 30
85+ self .cmap = plt .get_cmap ("cool" ) # mpl.cm.cool
86+
87+ self .imagealpha = 0.5
88+
89+ self .lw = 1
90+
91+ # fname = os.path.join("neurolib", "data", "resources", "clean_brain_white.png")
92+ fname = os .path .join (
93+ pathlib .Path (__file__ ).parent .absolute (), ".." , "data" , "resources" , "clean_brain_white.png"
94+ )
95+ else :
96+ # plt.style.use("light")
97+ # fname = os.path.join("neurolib", "data", "resources", "clean_brain.png")
98+ fname = os .path .join (pathlib .Path (__file__ ).parent .absolute (), ".." , "data" , "resources" , "clean_brain.png" )
99+
100+ print (fname )
101+ self .imgTopView = mpl .image .imread (fname )
102+
103+ self .pbar = tqdm .tqdm (total = self .nframes )
104+
105+ def update (self , i , ax , ax_rates = None , node_color = None , node_size = None , node_alpha = None , clear = True ):
106+ frame = int (i * self .frame_interval )
107+
108+ node_color = node_color or self .data [:, frame ]
109+ node_size = node_size or self .nodesize
110+ node_alpha = node_alpha or self .nodealpha
111+ if clear :
112+ ax .cla ()
113+ im = ax .imshow (self .imgTopView , alpha = self .imagealpha , origin = "upper" , extent = [40 , 202 , 28 , 240 ])
114+ ns = nx .draw_networkx_nodes (
115+ self .G ,
116+ pos = self .position ,
117+ node_color = node_color ,
118+ cmap = self .cmap ,
119+ vmin = self .vmin ,
120+ vmax = self .vmax ,
121+ node_size = node_size ,
122+ alpha = node_alpha ,
123+ ax = ax ,
124+ edgecolors = "k" ,
125+ )
126+ es = nx .draw_networkx_edges (
127+ self .G , pos = self .position , alpha = self .edgealpha , edge_color = self .edgecolor , ax = ax , width = self .edgeweight
128+ )
129+
130+ labels = {}
131+ for ni in range (self .n ):
132+ labels [ni ] = str (ni )
133+
134+ if self .draw_labels :
135+ nx .draw_networkx_labels (self .G , self .position , labels , font_size = 8 )
136+
137+ ax .set_axis_off ()
138+ ax .set_xlim (20 , 222 )
139+ ax .set_ylim (25 , 245 )
140+
141+ # timeseries
142+ if ax_rates :
143+ ax_rates .cla ()
144+ ax_rates .set_xticks ([])
145+ ax_rates .set_yticks ([])
146+ ax_rates .set_ylabel ("Brain activity" , fontsize = 8 )
147+
148+ t = np .linspace (0 , frame * self .dt , frame )
149+ ax_rates .plot (t , np .mean (self .data [:, :frame ], axis = 0 ).T , lw = self .lw )
150+
151+ t_total = self .data .shape [1 ] * self .dt
152+ ax_rates .set_xlim (0 , t_total )
153+
154+ self .pbar .update (1 )
155+ plt .tight_layout ()
156+ if clear :
157+ clear_output (wait = True )
158+
159+
160+ def plot_rates (model ):
161+ plt .figure (figsize = (4 , 1 ))
162+ plt_until = 10 * 1000
163+ plt .plot (model .t [model .t < plt_until ], model .output [:, model .t < plt_until ].T , lw = 0.5 )
164+
165+
166+ def plot_brain (
167+ model , ds , color = None , size = None , title = None , cbar = True , cmap = "RdBu" , clim = None , cbarticks = None , cbarticklabels = None
168+ ):
169+ """Dump and easy wrapper around the brain plotting function.
170+
171+ :param color: colors of nodes, defaults to None
172+ :type color: numpy.ndarray, optional
173+ :param size: size of the nodes, defaults to None
174+ :type size: numpy.ndarray, optional
175+ :raises ValueError: Raises error if node size is too big.
176+ """
177+ plot_data = model .output
178+ s = Brainplot (ds .Cmat , model .output , fps = 10 , darkmode = False )
179+ s .cmap = plt .get_cmap (cmap )
180+
181+ if color is None :
182+ color = np .ones (ds .Cmat .shape [0 ])
183+
184+ dpi = 300
185+ fig = plt .figure (dpi = dpi )
186+ ax = plt .gca ()
187+ if title :
188+ ax .set_title (title , fontsize = 26 )
189+
190+ if clim is None :
191+ s .vmin , s .vmax = np .min (color ), np .max (color )
192+ else :
193+ s .vmin , s .vmax = clim [0 ], clim [1 ]
194+
195+ if size is not None :
196+ node_size = size
197+ else :
198+ # some weird scaling of the color to a size
199+ def norm (what ):
200+ what = np .asarray (what .copy ())
201+ if np .min (what ) < np .max (what ):
202+ what -= np .min (what )
203+ what = np .divide (what , np .max (what ))
204+ return what
205+
206+ node_size = list (np .exp ((norm (color ) + 2 ) * 2 ))
207+
208+ if isinstance (color , np .ndarray ):
209+ color = list (color )
210+ if isinstance (node_size , np .ndarray ):
211+ node_size = list (node_size )
212+
213+ if np .max (node_size ) > 2000 :
214+ raise ValueError (f"node_size too big: { np .max (node_size )} " )
215+ s .update (0 , ax , node_color = color , node_size = node_size , clear = False )
216+ if cbar :
217+ # cbaxes = fig.add_axes([0.68, 0.1, 0.015, 0.7])
218+ cbaxes = fig .add_axes ([0.75 , 0.1 , 0.015 , 0.7 ])
219+ sm = plt .cm .ScalarMappable (cmap = s .cmap , norm = plt .Normalize (vmin = s .vmin , vmax = s .vmax ))
220+ cbar = plt .colorbar (sm , cbaxes , ticks = cbarticks )
221+ cbar .ax .tick_params (labelsize = 16 )
222+ if cbarticklabels :
223+ cbar .ax .set_yticklabels (cbarticklabels )
0 commit comments