Skip to content

Commit 6be0d37

Browse files
authored
spring cleaning: removing bugs and code smells (#261)
* removed non-existing argument from aln docstring * removed params['bold'] option * cleaned up bold initialization check * removed unreachable condition * fix continue_run order * fix typo * clean up simulateBold * fix continue_run for chunkwise=True * simplified outputDict.items() * major fixes: setOutput / BOLD * remove unused arguments * fixed aln minimal notebook * fixes to multimodel, testexploration, and continue_run * reverted to automatic append for BOLD outputs * remove duplicate arguments * removed unnecessary append after reverting to default append for BOLD * remove unused `self.start_t`, fix BOLD append for multimodel * allow `continue_run=True` on first model run * allow continue_run=True on first multimodel run * removed first run continue_run warnings
1 parent 4e9454a commit 6be0d37

File tree

12 files changed

+246
-300
lines changed

12 files changed

+246
-300
lines changed

examples/example-0-aln-minimal.ipynb

Lines changed: 85 additions & 76 deletions
Large diffs are not rendered by default.

neurolib/models/aln/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(self, params=None, Cmat=None, Dmat=None, lookupTableFileName=None,
6060
:param Dmat: Distance matrix between all nodes (in mm)
6161
:param lookupTableFileName: Filename for precomputed transfer functions and tables
6262
:param seed: Random number generator seed
63-
:param simulateChunkwise: Chunkwise time integration (for lower memory use)
6463
"""
6564

6665
# Global attributes

neurolib/models/bold/model.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def __init__(self, N, dt, normalize_input=False, normalize_max=50):
3636
self.V_BOLD = np.ones((N,))
3737
# Blood volume
3838

39-
def run(self, activity, append=False):
39+
def run(self, activity):
4040
"""Runs the Balloon-Windkessel BOLD simulation.
4141
4242
Parameters:
4343
:param activity: Neuronal firing rate in Hz
44-
44+
4545
:param activity: Neuronal firing rate in Hz
4646
:type activity: numpy.ndarray
4747
"""
@@ -50,7 +50,6 @@ def run(self, activity, append=False):
5050
BOLD_chunk, self.X_BOLD, self.F_BOLD, self.Q_BOLD, self.V_BOLD = simulateBOLD(
5151
activity,
5252
self.dt * 1e-3,
53-
10000 * np.ones((self.N,)),
5453
X=self.X_BOLD,
5554
F=self.F_BOLD,
5655
Q=self.Q_BOLD,
@@ -67,19 +66,8 @@ def run(self, activity, append=False):
6766
* self.dt
6867
)
6968

70-
if self.BOLD.shape[1] == 0:
71-
# add new data
72-
self.t_BOLD = t_BOLD_resampled
73-
self.BOLD = BOLD_resampled
74-
elif append is True:
75-
# append new data to old data
76-
self.t_BOLD = np.hstack((self.t_BOLD, t_BOLD_resampled))
77-
self.BOLD = np.hstack((self.BOLD, BOLD_resampled))
78-
else:
79-
# overwrite old data
80-
self.t_BOLD = t_BOLD_resampled
81-
self.BOLD = BOLD_resampled
82-
69+
self.t_BOLD = t_BOLD_resampled
70+
self.BOLD = BOLD_resampled
8371
self.BOLD_chunk = BOLD_resampled
8472

8573
self.idxLastT = self.idxLastT + activity.shape[1]

neurolib/models/bold/timeIntegration.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numba
33

44

5-
def simulateBOLD(Z, dt, voxelCounts, X=None, F=None, Q=None, V=None):
5+
def simulateBOLD(Z, dt, voxelCounts=None, X=None, F=None, Q=None, V=None):
66
"""Simulate BOLD activity using the Balloon-Windkessel model.
77
See Friston 2000, Friston 2003 and Deco 2013 for reference on how the BOLD signal is simulated.
88
The returned BOLD signal should be downsampled to be comparable to a recorded fMRI signal.
@@ -11,7 +11,7 @@ def simulateBOLD(Z, dt, voxelCounts, X=None, F=None, Q=None, V=None):
1111
:type Z: numpy.ndarray
1212
:param dt: dt of input activity in s
1313
:type dt: float
14-
:param voxelCounts: Number of voxels in each region (not used yet!)
14+
:param voxelCounts: Number of voxels in each region (not used yet!) # TODO
1515
:type voxelCounts: numpy.ndarray
1616
:param X: Initial values of Vasodilatory signal, defaults to None
1717
:type X: numpy.ndarray, optional
@@ -28,9 +28,6 @@ def simulateBOLD(Z, dt, voxelCounts, X=None, F=None, Q=None, V=None):
2828

2929
N = np.shape(Z)[0]
3030

31-
if "voxelCounts" not in globals():
32-
voxelCounts = np.ones((N,))
33-
3431
# Balloon-Windkessel model parameters (from Friston 2003):
3532
# Friston paper: Nonlinear responses in fMRI: The balloon model, Volterra kernels, and other hemodynamics
3633
# Note: the distribution of each Balloon-Windkessel models parameters are given per voxel

neurolib/models/model.py

Lines changed: 78 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -66,45 +66,47 @@ def initializeBold(self):
6666
self.boldInitialized = True
6767
# logging.info(f"{self.name}: BOLD model initialized.")
6868

69-
def simulateBold(self, t, variables, append=False):
69+
def get_bold_variable(self, variables):
70+
default_index = self.state_vars.index(self.default_output)
71+
return variables[default_index]
72+
73+
def simulateBold(self, bold_variable, append=True):
7074
"""Gets the default output of the model and simulates the BOLD model.
7175
Adds the simulated BOLD signal to outputs.
7276
"""
73-
if self.boldInitialized:
74-
# first we loop through all state variables
75-
for svn, sv in zip(self.state_vars, variables):
76-
# the default output is used as the input for the bold model
77-
if svn == self.default_output:
78-
bold_input = sv[:, self.startindt :]
79-
# logging.debug(f"BOLD input `{svn}` of shape {bold_input.shape}")
80-
if bold_input.shape[1] >= self.boldModel.samplingRate_NDt:
81-
# only if the length of the output has a zero mod to the sampling rate,
82-
# the downsampled output from the boldModel can correctly appended to previous data
83-
# so: we are lazy here and simply disable appending in that case ...
84-
if not bold_input.shape[1] % self.boldModel.samplingRate_NDt == 0:
85-
append = False
86-
logging.warn(
87-
f"Output size {bold_input.shape[1]} is not a multiple of BOLD sampling length { self.boldModel.samplingRate_NDt}, will not append data."
88-
)
89-
logging.debug(f"Simulating BOLD: boldModel.run(append={append})")
90-
91-
# transform bold input according to self.boldInputTransform
92-
if self.boldInputTransform:
93-
bold_input = self.boldInputTransform(bold_input)
94-
95-
# simulate bold model
96-
self.boldModel.run(bold_input, append=append)
97-
98-
t_BOLD = self.boldModel.t_BOLD
99-
BOLD = self.boldModel.BOLD
100-
self.setOutput("BOLD.t_BOLD", t_BOLD)
101-
self.setOutput("BOLD.BOLD", BOLD)
102-
else:
103-
logging.warn(
104-
f"Will not simulate BOLD if output {bold_input.shape[1]*self.params['dt']} not at least of duration {self.boldModel.samplingRate_NDt*self.params['dt']}"
105-
)
106-
else:
77+
if not self.boldInitialized:
10778
logging.warn("BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`")
79+
return
80+
81+
bold_input = bold_variable[:, self.startindt :]
82+
# logging.debug(f"BOLD input `{svn}` of shape {bold_input.shape}")
83+
if not bold_input.shape[1] >= self.boldModel.samplingRate_NDt:
84+
logging.warn(
85+
f"Will not simulate BOLD if output {bold_input.shape[1]*self.params['dt']} not at least of duration {self.boldModel.samplingRate_NDt*self.params['dt']}"
86+
)
87+
return
88+
89+
# only if the length of the output has a zero mod to the sampling rate,
90+
# the downsampled output from the boldModel can correctly appended to previous data
91+
# so: we are lazy here and simply disable appending in that case ...
92+
if append and not bold_input.shape[1] % self.boldModel.samplingRate_NDt == 0:
93+
append = False
94+
logging.warn(
95+
f"Output size {bold_input.shape[1]} is not a multiple of BOLD sampling length { self.boldModel.samplingRate_NDt}, will not append data."
96+
)
97+
logging.debug(f"Simulating BOLD: boldModel.run()")
98+
99+
# transform bold input according to self.boldInputTransform
100+
if self.boldInputTransform:
101+
bold_input = self.boldInputTransform(bold_input)
102+
103+
# simulate bold model
104+
self.boldModel.run(bold_input)
105+
106+
t_BOLD = self.boldModel.t_BOLD
107+
BOLD = self.boldModel.BOLD
108+
self.setOutput("BOLD.t_BOLD", t_BOLD, append=append)
109+
self.setOutput("BOLD.BOLD", BOLD, append=append)
108110

109111
def checkChunkwise(self, chunksize):
110112
"""Checks if the model fulfills requirements for chunkwise simulation.
@@ -172,21 +174,16 @@ def initializeRun(self, initializeBold=False):
172174
# check dt / sampling_dt
173175
self.setSamplingDt()
174176

175-
# force bold if params['bold'] == True
176-
if self.params.get("bold"):
177-
initializeBold = True
178177
# set up the bold model, if it didn't happen yet
179178
if initializeBold and not self.boldInitialized:
180179
self.initializeBold()
181180

182181
def run(
183182
self,
184-
inputs=None,
185183
chunkwise=False,
186184
chunksize=None,
187185
bold=False,
188-
append=False,
189-
append_outputs=None,
186+
append_outputs=False,
190187
continue_run=False,
191188
):
192189
"""
@@ -195,7 +192,7 @@ def run(
195192
The model can be run in three different ways:
196193
1) `model.run()` starts a new run.
197194
2) `model.run(chunkwise=True)` runs the simulation in chunks of length `chunksize`.
198-
3) `mode.run(continue_run=True)` continues the simulation of a previous run.
195+
3) `mode.run(continue_run=True)` continues the simulation of a previous run. This has no effect during the first run.
199196
200197
:param inputs: list of inputs to the model, must have the same order as model.input_vars. Note: no sanity check is performed for performance reasons. Take care of the inputs yourself.
201198
:type inputs: list[np.ndarray|]
@@ -205,28 +202,24 @@ def run(
205202
:type chunksize: int, optional
206203
:param bold: simulate BOLD signal (only for chunkwise integration), defaults to False
207204
:type bold: bool, optional
208-
:param append: append the chunkwise outputs to the outputs attribute, defaults to False, defaults to False
209-
:type append: bool, optional
210-
:param continue_run: continue a simulation by using the initial values from a previous simulation
205+
:param append_outputs: append new and chunkwise outputs to the outputs attribute, defaults to False. Note: BOLD outputs are always appended.
206+
:type append_outputs: bool, optional
207+
:param continue_run: continue a simulation by using the initial values from a previous simulation. This has no effect during the first run.
211208
:type continue_run: bool
212209
"""
213-
# TODO: legacy argument support
214-
if append_outputs is not None:
215-
append = append_outputs
210+
self.initializeRun(initializeBold=bold)
216211

217212
# if a previous run is not to be continued clear the model's state
218-
if continue_run is False:
213+
if continue_run:
214+
self.setInitialValuesToLastState()
215+
else:
219216
self.clearModelState()
220217

221-
self.initializeRun(initializeBold=bold)
222-
223218
# enable chunkwise if chunksize is set
224219
chunkwise = chunkwise if chunksize is None else True
225220

226221
if chunkwise is False:
227-
self.integrate(append_outputs=append, simulate_bold=bold)
228-
if continue_run:
229-
self.setInitialValuesToLastState()
222+
self.integrate(append_outputs=append_outputs, simulate_bold=bold)
230223

231224
else:
232225
if chunksize is None:
@@ -235,10 +228,8 @@ def run(
235228
# check if model is safe for chunkwise integration
236229
# and whether sampling_dt is compatible with duration and chunksize
237230
self.checkChunkwise(chunksize)
238-
if bold and not self.boldInitialized:
239-
logging.warn(f"{self.name}: BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`")
240-
bold = False
241-
self.integrateChunkwise(chunksize=chunksize, bold=bold, append_outputs=append)
231+
232+
self.integrateChunkwise(chunksize=chunksize, bold=bold, append_outputs=append_outputs)
242233

243234
# check if there was a problem with the simulated data
244235
self.checkOutputs()
@@ -260,20 +251,17 @@ def checkOutputs(self):
260251
def integrate(self, append_outputs=False, simulate_bold=False):
261252
"""Calls each models `integration` function and saves the state and the outputs of the model.
262253
263-
:param append: append the chunkwise outputs to the outputs attribute, defaults to False, defaults to False
254+
:param append: append the chunkwise outputs to the outputs attribute, defaults to False
264255
:type append: bool, optional
265256
"""
266257
# run integration
267258
t, *variables = self.integration(self.params)
268259
self.storeOutputsAndStates(t, variables, append=append_outputs)
269260

270-
# force bold if params['bold'] == True
271-
if self.params.get("bold"):
272-
simulate_bold = True
273-
274261
# bold simulation after integration
275262
if simulate_bold and self.boldInitialized:
276-
self.simulateBold(t, variables, append=True)
263+
bold_variable = self.get_bold_variable(variables)
264+
self.simulateBold(bold_variable, append=True)
277265

278266
def integrateChunkwise(self, chunksize, bold=False, append_outputs=False):
279267
"""Repeatedly calls the chunkwise integration for the whole duration of the simulation.
@@ -311,7 +299,7 @@ def clearModelState(self):
311299
self.state = dotdict({})
312300
self.outputs = dotdict({})
313301
# reinitialize bold model
314-
if self.params.get("bold"):
302+
if self.boldInitialized:
315303
self.initializeBold()
316304

317305
def storeOutputsAndStates(self, t, variables, append=False):
@@ -335,6 +323,8 @@ def storeOutputsAndStates(self, t, variables, append=False):
335323

336324
def setInitialValuesToLastState(self):
337325
"""Reads the last state of the model and sets the initial conditions to that state for continuing a simulation."""
326+
if not all([sv in self.state for sv in self.state_vars]):
327+
return
338328
for iv, sv in zip(self.init_vars, self.state_vars):
339329
# if state variables are one-dimensional (in space only)
340330
if (self.state[sv].ndim == 0) or (self.state[sv].ndim == 1):
@@ -474,25 +464,28 @@ def setOutput(self, name, data, append=False, removeICs=False):
474464
raise ValueError(f"Don't know how to truncate data of shape {data.shape}.")
475465

476466
# subsample to sampling dt
477-
if data.ndim == 1:
478-
data = data[:: self.sample_every]
479-
elif data.ndim == 2:
480-
data = data[:, :: self.sample_every]
481-
else:
482-
raise ValueError(f"Don't know how to subsample data of shape {data.shape}.")
467+
if data.shape[-1] >= self.params["duration"] - self.startindt:
468+
if data.ndim == 1:
469+
data = data[:: self.sample_every]
470+
elif data.ndim == 2:
471+
data = data[:, :: self.sample_every]
472+
else:
473+
raise ValueError(f"Don't know how to subsample data of shape {data.shape}.")
474+
475+
def save_leaf(node, name, data, append):
476+
if name in node:
477+
if data.ndim == 1 and name == "t":
478+
# special treatment for time data:
479+
# increment the time by the last recorded duration
480+
data += node[name][-1]
481+
if append and data.shape[-1] != 0:
482+
data = np.hstack((node[name], data))
483+
node[name] = data
484+
return node
483485

484486
# if the output is a single name (not dot.separated)
485487
if "." not in name:
486-
# append data
487-
if append and name in self.outputs:
488-
# special treatment for time data:
489-
# increment the time by the last recorded duration
490-
if name == "t":
491-
data += self.outputs[name][-1]
492-
self.outputs[name] = np.hstack((self.outputs[name], data))
493-
else:
494-
# save all data into output dict
495-
self.outputs[name] = data
488+
save_leaf(self.outputs, name, data, append)
496489
# set output as an attribute
497490
setattr(self, name, self.outputs[name])
498491
else:
@@ -503,18 +496,10 @@ def setOutput(self, name, data, append=False, removeICs=False):
503496
for i, k in enumerate(keys):
504497
# if it's the last iteration, store data
505498
if i == len(keys) - 1:
506-
# TODO: this needs to be append-aware like above
507-
# if append:
508-
# if k == "t":
509-
# data += level[k][-1]
510-
# level[k] = np.hstack((level[k], data))
511-
# else:
512-
# level[k] = data
513-
level[k] = data
499+
level = save_leaf(level, k, data, append)
514500
# if key is in outputs, then go deeper
515501
elif k in level:
516502
level = level[k]
517-
setattr(self, k, level)
518503
# if it's a new key, create new nested dictionary, set attribute, then go deeper
519504
else:
520505
level[k] = dotdict({})
@@ -604,11 +589,9 @@ def xr(self, group=""):
604589
assert len(timeDictKey) > 0, f"No time array found (starting with t) in output group {group}."
605590
t = outputDict[timeDictKey].copy()
606591
del outputDict[timeDictKey]
607-
outputs = []
608-
outputNames = []
609-
for key, value in outputDict.items():
610-
outputNames.append(key)
611-
outputs.append(value)
592+
593+
outputNames, outputs = zip(*outputDict.items())
594+
outputNames = list(outputNames)
612595

613596
nNodes = outputs[0].shape[0]
614597
nodes = list(range(nNodes))

0 commit comments

Comments
 (0)