Skip to content

Commit 45ec0b2

Browse files
add the simnib script into repository
1 parent 9ce56eb commit 45ec0b2

File tree

2 files changed

+3931
-0
lines changed

2 files changed

+3931
-0
lines changed

DATA-e-field-simnibs-to-aal2.ipynb

Lines changed: 3708 additions & 0 deletions
Large diffs are not rendered by default.

neurolib/utils/brainplot.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import os
2+
import numpy as np
3+
import matplotlib as mpl
4+
import matplotlib.pyplot as plt
5+
6+
import logging
7+
import tqdm
8+
import networkx as nx
9+
import nibabel as nib
10+
import pathlib
11+
12+
from IPython.display import clear_output
13+
from neurolib.utils import atlases
14+
15+
16+
class Brainplot:
17+
def __init__(self, Cmat, data, nframes=None, dt=0.1, fps=25, labels=False, darkmode=True):
18+
self.sc = Cmat
19+
self.n = self.sc.shape[0]
20+
21+
self.data = data
22+
self.darkmode = darkmode
23+
24+
self.G = nx.Graph()
25+
self.G.add_nodes_from(range(self.n))
26+
27+
coords = {}
28+
atlas = atlases.AutomatedAnatomicalParcellation2()
29+
for i, c in enumerate(atlas.coords()):
30+
coords[i] = [c[0], c[1]]
31+
self.position = coords
32+
33+
self.edge_threshold = 0.01
34+
35+
self.fps = fps
36+
self.dt = dt
37+
38+
nframes = nframes or int((data.shape[1] * self.dt / 1000) * self.fps) # 20 fps default
39+
logging.info(f"Defaulting to {nframes} frames at {self.fps} fp/s")
40+
self.nframes = nframes
41+
42+
self.frame_interval = self.data.shape[1] // self.nframes
43+
44+
self.interval = int(self.frame_interval * self.dt)
45+
46+
self.draw_labels = labels
47+
48+
for t in range(self.n):
49+
# print t
50+
for s in range(t):
51+
# print( n, t, s)
52+
if self.sc[t, s] > self.edge_threshold:
53+
# print( 'edge', t, s, self.sc[t,s])
54+
self.G.add_edge(t, s)
55+
56+
# node color map
57+
self.cmap = plt.get_cmap("plasma") # mpl.cm.cool
58+
59+
# default style
60+
61+
self.imagealpha = 0.5
62+
63+
self.edgecolor = "k"
64+
self.edgealpha = 0.8
65+
self.edgeweight = 1.0
66+
67+
self.nodesize = 50
68+
self.nodealpha = 0.8
69+
self.vmin = 0
70+
self.vmax = 50
71+
72+
self.lw = 0.5
73+
74+
if self.darkmode:
75+
plt.style.use("dark")
76+
# let's choose a cyberpunk style for the dark theme
77+
self.edgecolor = "#37f522"
78+
self.edgeweight = 0.5
79+
self.edgealpha = 0.6
80+
81+
self.nodesize = 40
82+
self.nodealpha = 0.8
83+
self.vmin = 0
84+
self.vmax = 30
85+
self.cmap = plt.get_cmap("cool") # mpl.cm.cool
86+
87+
self.imagealpha = 0.5
88+
89+
self.lw = 1
90+
91+
# fname = os.path.join("neurolib", "data", "resources", "clean_brain_white.png")
92+
fname = os.path.join(
93+
pathlib.Path(__file__).parent.absolute(), "..", "data", "resources", "clean_brain_white.png"
94+
)
95+
else:
96+
# plt.style.use("light")
97+
# fname = os.path.join("neurolib", "data", "resources", "clean_brain.png")
98+
fname = os.path.join(pathlib.Path(__file__).parent.absolute(), "..", "data", "resources", "clean_brain.png")
99+
100+
print(fname)
101+
self.imgTopView = mpl.image.imread(fname)
102+
103+
self.pbar = tqdm.tqdm(total=self.nframes)
104+
105+
def update(self, i, ax, ax_rates=None, node_color=None, node_size=None, node_alpha=None, clear=True):
106+
frame = int(i * self.frame_interval)
107+
108+
node_color = node_color or self.data[:, frame]
109+
node_size = node_size or self.nodesize
110+
node_alpha = node_alpha or self.nodealpha
111+
if clear:
112+
ax.cla()
113+
im = ax.imshow(self.imgTopView, alpha=self.imagealpha, origin="upper", extent=[40, 202, 28, 240])
114+
ns = nx.draw_networkx_nodes(
115+
self.G,
116+
pos=self.position,
117+
node_color=node_color,
118+
cmap=self.cmap,
119+
vmin=self.vmin,
120+
vmax=self.vmax,
121+
node_size=node_size,
122+
alpha=node_alpha,
123+
ax=ax,
124+
edgecolors="k",
125+
)
126+
es = nx.draw_networkx_edges(
127+
self.G, pos=self.position, alpha=self.edgealpha, edge_color=self.edgecolor, ax=ax, width=self.edgeweight
128+
)
129+
130+
labels = {}
131+
for ni in range(self.n):
132+
labels[ni] = str(ni)
133+
134+
if self.draw_labels:
135+
nx.draw_networkx_labels(self.G, self.position, labels, font_size=8)
136+
137+
ax.set_axis_off()
138+
ax.set_xlim(20, 222)
139+
ax.set_ylim(25, 245)
140+
141+
# timeseries
142+
if ax_rates:
143+
ax_rates.cla()
144+
ax_rates.set_xticks([])
145+
ax_rates.set_yticks([])
146+
ax_rates.set_ylabel("Brain activity", fontsize=8)
147+
148+
t = np.linspace(0, frame * self.dt, frame)
149+
ax_rates.plot(t, np.mean(self.data[:, :frame], axis=0).T, lw=self.lw)
150+
151+
t_total = self.data.shape[1] * self.dt
152+
ax_rates.set_xlim(0, t_total)
153+
154+
self.pbar.update(1)
155+
plt.tight_layout()
156+
if clear:
157+
clear_output(wait=True)
158+
159+
160+
def plot_rates(model):
161+
plt.figure(figsize=(4, 1))
162+
plt_until = 10 * 1000
163+
plt.plot(model.t[model.t < plt_until], model.output[:, model.t < plt_until].T, lw=0.5)
164+
165+
166+
def plot_brain(
167+
model, ds, color=None, size=None, title=None, cbar=True, cmap="RdBu", clim=None, cbarticks=None, cbarticklabels=None
168+
):
169+
"""Dump and easy wrapper around the brain plotting function.
170+
171+
:param color: colors of nodes, defaults to None
172+
:type color: numpy.ndarray, optional
173+
:param size: size of the nodes, defaults to None
174+
:type size: numpy.ndarray, optional
175+
:raises ValueError: Raises error if node size is too big.
176+
"""
177+
plot_data = model.output
178+
s = Brainplot(ds.Cmat, model.output, fps=10, darkmode=False)
179+
s.cmap = plt.get_cmap(cmap)
180+
181+
if color is None:
182+
color = np.ones(ds.Cmat.shape[0])
183+
184+
dpi = 300
185+
fig = plt.figure(dpi=dpi)
186+
ax = plt.gca()
187+
if title:
188+
ax.set_title(title, fontsize=26)
189+
190+
if clim is None:
191+
s.vmin, s.vmax = np.min(color), np.max(color)
192+
else:
193+
s.vmin, s.vmax = clim[0], clim[1]
194+
195+
if size is not None:
196+
node_size = size
197+
else:
198+
# some weird scaling of the color to a size
199+
def norm(what):
200+
what = np.asarray(what.copy())
201+
if np.min(what) < np.max(what):
202+
what -= np.min(what)
203+
what = np.divide(what, np.max(what))
204+
return what
205+
206+
node_size = list(np.exp((norm(color) + 2) * 2))
207+
208+
if isinstance(color, np.ndarray):
209+
color = list(color)
210+
if isinstance(node_size, np.ndarray):
211+
node_size = list(node_size)
212+
213+
if np.max(node_size) > 2000:
214+
raise ValueError(f"node_size too big: {np.max(node_size)}")
215+
s.update(0, ax, node_color=color, node_size=node_size, clear=False)
216+
if cbar:
217+
# cbaxes = fig.add_axes([0.68, 0.1, 0.015, 0.7])
218+
cbaxes = fig.add_axes([0.75, 0.1, 0.015, 0.7])
219+
sm = plt.cm.ScalarMappable(cmap=s.cmap, norm=plt.Normalize(vmin=s.vmin, vmax=s.vmax))
220+
cbar = plt.colorbar(sm, cbaxes, ticks=cbarticks)
221+
cbar.ax.tick_params(labelsize=16)
222+
if cbarticklabels:
223+
cbar.ax.set_yticklabels(cbarticklabels)

0 commit comments

Comments
 (0)