Skip to content

Commit 0776d7a

Browse files
authored
Added control interval (#256)
* added control interval parameter and test * fix control interval test * fix control interval test
1 parent 9071ca8 commit 0776d7a

File tree

5 files changed

+93
-18
lines changed

5 files changed

+93
-18
lines changed

neurolib/control/optimal_control/oc_aln/oc_aln.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
import numpy as np
33

44
from neurolib.control.optimal_control.oc import OC
5-
from neurolib.models.aln.timeIntegration import compute_hx, compute_hx_nw, Duh, Dxdoth, compute_hx_de, compute_hx_di
5+
from neurolib.models.aln.timeIntegration import (
6+
compute_hx,
7+
compute_hx_nw,
8+
Duh,
9+
Dxdoth,
10+
compute_hx_de,
11+
compute_hx_di,
12+
)
613

714

815
class OcAln(OC):
@@ -21,6 +28,7 @@ def __init__(
2128
weights=None,
2229
print_array=[],
2330
cost_interval=(None, None),
31+
control_interval=(None, None),
2432
cost_matrix=None,
2533
control_matrix=None,
2634
M=1,
@@ -34,6 +42,7 @@ def __init__(
3442
print_array=print_array,
3543
cost_interval=cost_interval,
3644
cost_matrix=cost_matrix,
45+
control_interval=control_interval,
3746
control_matrix=control_matrix,
3847
M=M,
3948
M_validation=M_validation,
@@ -197,7 +206,9 @@ def compute_hx_list(self):
197206
hx_de = self.compute_hx_de()
198207
hx_di = self.compute_hx_di()
199208

200-
return numba.typed.List([hx, hx_de, hx_di]), numba.typed.List([0, self.ndt_de, self.ndt_di])
209+
return numba.typed.List([hx, hx_de, hx_di]), numba.typed.List(
210+
[0, self.ndt_de, self.ndt_di]
211+
)
201212

202213
def compute_hx(self):
203214
"""Jacobians of ALNModel wrt. the 'e'- and 'i'-variable for each time step.
@@ -317,7 +328,9 @@ def get_fullstate(self):
317328
if t <= T - 2:
318329
self.model.params[iv] = control[:, iv_ind, t : t + 2]
319330
elif t == T - 1:
320-
self.model.params[iv] = np.concatenate((control[:, iv_ind, t:], np.zeros((N, 1))), axis=1)
331+
self.model.params[iv] = np.concatenate(
332+
(control[:, iv_ind, t:], np.zeros((N, 1))), axis=1
333+
)
321334
else:
322335
self.model.params[iv] = 0.0
323336
self.model.run()
@@ -349,11 +362,19 @@ def setasinit(self, fullstate, t):
349362

350363
for n in range(N):
351364
for v in range(V):
352-
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
365+
if (
366+
"rates" in self.model.init_vars[v]
367+
or "IA" in self.model.init_vars[v]
368+
):
353369
if t >= T:
354-
self.model.params[self.model.init_vars[v]] = fullstate[:, v, t - T : t + 1]
370+
self.model.params[self.model.init_vars[v]] = fullstate[
371+
:, v, t - T : t + 1
372+
]
355373
else:
356-
init = np.concatenate((fullstate[:, v, -T + t + 1 :], fullstate[:, v, : t + 1]), axis=1)
374+
init = np.concatenate(
375+
(fullstate[:, v, -T + t + 1 :], fullstate[:, v, : t + 1]),
376+
axis=1,
377+
)
357378
self.model.params[self.model.init_vars[v]] = init
358379
else:
359380
self.model.params[self.model.init_vars[v]] = fullstate[:, v, t]
@@ -371,8 +392,13 @@ def getinitstate(self):
371392

372393
for n in range(N):
373394
for v in range(V):
374-
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
375-
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][n, -T:]
395+
if (
396+
"rates" in self.model.init_vars[v]
397+
or "IA" in self.model.init_vars[v]
398+
):
399+
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][
400+
n, -T:
401+
]
376402

377403
else:
378404
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][n]
@@ -389,7 +415,10 @@ def getfinalstate(self):
389415
state = np.zeros((N, V))
390416
for n in range(N):
391417
for v in range(V):
392-
if "rates" in self.model.state_vars[v] or "IA" in self.model.state_vars[v]:
418+
if (
419+
"rates" in self.model.state_vars[v]
420+
or "IA" in self.model.state_vars[v]
421+
):
393422
state[n, v] = self.model.state[self.model.state_vars[v]][n, -1]
394423

395424
else:
@@ -408,7 +437,10 @@ def setinitstate(self, state):
408437

409438
for n in range(N):
410439
for v in range(V):
411-
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
440+
if (
441+
"rates" in self.model.init_vars[v]
442+
or "IA" in self.model.init_vars[v]
443+
):
412444
self.model.params[self.model.init_vars[v]] = state[:, v, -T:]
413445
else:
414446
self.model.params[self.model.init_vars[v]] = state[:, v, -1]

neurolib/control/optimal_control/oc_fhn/oc_fhn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
maximum_control_strength=None,
2020
print_array=[],
2121
cost_interval=(None, None),
22+
control_interval=(None, None),
2223
cost_matrix=None,
2324
control_matrix=None,
2425
M=1,
@@ -32,6 +33,7 @@ def __init__(
3233
maximum_control_strength=maximum_control_strength,
3334
print_array=print_array,
3435
cost_interval=cost_interval,
36+
control_interval=control_interval,
3537
cost_matrix=cost_matrix,
3638
control_matrix=control_matrix,
3739
M=M,

neurolib/control/optimal_control/oc_hopf/oc_hopf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
maximum_control_strength=None,
2222
print_array=[],
2323
cost_interval=(None, None),
24+
control_interval=(None, None),
2425
cost_matrix=None,
2526
control_matrix=None,
2627
M=1,
@@ -34,6 +35,7 @@ def __init__(
3435
maximum_control_strength=maximum_control_strength,
3536
print_array=print_array,
3637
cost_interval=cost_interval,
38+
control_interval=control_interval,
3739
cost_matrix=cost_matrix,
3840
control_matrix=control_matrix,
3941
M=M,

neurolib/control/optimal_control/oc_wc/oc_wc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
weights=None,
2121
print_array=[],
2222
cost_interval=(None, None),
23+
control_interval=(None, None),
2324
cost_matrix=None,
2425
control_matrix=None,
2526
M=1,
@@ -32,6 +33,7 @@ def __init__(
3233
weights=weights,
3334
print_array=print_array,
3435
cost_interval=cost_interval,
36+
control_interval=control_interval,
3537
cost_matrix=cost_matrix,
3638
control_matrix=control_matrix,
3739
M=M,

tests/control/optimal_control/test_oc_fhn.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def test_1n(self):
2626
cost_mat = np.zeros((model.params.N, len(model.output_vars)))
2727
control_mat = np.zeros((model.params.N, len(model.state_vars)))
2828
control_mat[0, input_channel] = 1.0 # only allow inputs to input_channel
29-
cost_mat[0, np.abs(input_channel - 1).astype(int)] = 1.0 # only measure other channel
29+
cost_mat[
30+
0, np.abs(input_channel - 1).astype(int)
31+
] = 1.0 # only measure other channel
3032

3133
test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6)
3234
model.params[model.input_vars[input_channel]] = p.TEST_INPUT_1N_6
@@ -52,7 +54,9 @@ def test_1n(self):
5254
model_controlled.optimize(p.ITERATIONS)
5355
control = model_controlled.control
5456

55-
c_diff = (np.abs(control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]),)
57+
c_diff = (
58+
np.abs(control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]),
59+
)
5660

5761
if np.amax(c_diff) < p.LIMIT_DIFF:
5862
control_coincide = True
@@ -99,7 +103,11 @@ def test_2n(self):
99103
)
100104

101105
model_controlled.control = np.concatenate(
102-
[p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], axis=1
106+
[
107+
p.INIT_INPUT_2N_8[:, np.newaxis, :],
108+
p.ZERO_INPUT_2N_8[:, np.newaxis, :],
109+
],
110+
axis=1,
103111
)
104112
model_controlled.update_input()
105113

@@ -261,7 +269,9 @@ def test_u_max_no_optimizations(self):
261269
control_matrix=control_mat,
262270
)
263271

264-
self.assertTrue(np.max(np.abs(model_controlled.control) <= maximum_control_strength))
272+
self.assertTrue(
273+
np.max(np.abs(model_controlled.control) <= maximum_control_strength)
274+
)
265275

266276
# Arbitrary network and control setting, initial control violates the maximum absolute criterion.
267277
def test_u_max_after_optimizations(self):
@@ -289,7 +299,9 @@ def test_u_max_after_optimizations(self):
289299
)
290300

291301
model_controlled.optimize(1)
292-
self.assertTrue(np.max(np.abs(model_controlled.control) <= maximum_control_strength))
302+
self.assertTrue(
303+
np.max(np.abs(model_controlled.control) <= maximum_control_strength)
304+
)
293305

294306
def test_adjust_init(self):
295307
print("Test adjust_init function of OC class")
@@ -327,7 +339,10 @@ def test_adjust_init(self):
327339
for init_var0 in model.init_vars:
328340
if "ou" in init_var0:
329341
continue
330-
self.assertTrue(model_controlled.model.params[init_var0].shape == targetinitshape)
342+
self.assertTrue(
343+
model_controlled.model.params[init_var0].shape
344+
== targetinitshape
345+
)
331346

332347
def test_adjust_input(self):
333348
print("Test test_adjust_input function of OC class")
@@ -337,7 +352,9 @@ def test_adjust_input(self):
337352
model = FHNModel(Cmat=cmat, Dmat=dmat)
338353
model.params.duration = p.TEST_DURATION_6
339354

340-
target = np.zeros((model.params.N, len(model.state_vars), p.TEST_INPUT_2N_6.shape[1]))
355+
target = np.zeros(
356+
(model.params.N, len(model.state_vars), p.TEST_INPUT_2N_6.shape[1])
357+
)
341358
targetinputshape = (target.shape[0], target.shape[2])
342359

343360
for test_input in [
@@ -358,7 +375,27 @@ def test_adjust_input(self):
358375
)
359376

360377
for input_var0 in model.input_vars:
361-
self.assertTrue(model_controlled.model.params[input_var0].shape == targetinputshape)
378+
self.assertTrue(
379+
model_controlled.model.params[input_var0].shape
380+
== targetinputshape
381+
)
382+
383+
# tests if the control is only active in the control interval
384+
# single-node case
385+
def test_onenode_control_interval(self):
386+
print("Test OC for control_interval = [0,0] in single-node model")
387+
model = FHNModel()
388+
389+
model.params["duration"] = p.TEST_DURATION_8
390+
test_oc_utils.setinitzero_1n(model)
391+
392+
test_oc_utils.set_input(model, p.TEST_INPUT_1N_8)
393+
model.run()
394+
target = test_oc_utils.gettarget_1n(model)
395+
396+
model_controlled = oc_fhn.OcFhn(model, target, control_interval=(0, 1))
397+
model_controlled.optimize(1)
398+
self.assertEqual(np.amax(np.abs(model_controlled.control[:, :, 1:])), 0.0)
362399

363400
# tests if the cost is independent of the integration time step
364401
def test_cost_dt(self):

0 commit comments

Comments
 (0)