Skip to content

Commit 9071ca8

Browse files
authored
Revise control tests (#258)
* revise control tests * fix issue in test
1 parent 4436468 commit 9071ca8

File tree

2 files changed

+82
-71
lines changed

2 files changed

+82
-71
lines changed

tests/control/optimal_control/test_oc.py

Lines changed: 44 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,16 @@
1212
from neurolib.models.wc import WCModel
1313
from neurolib.utils.stimulus import ZeroInput
1414

15+
import test_oc_utils as test_oc_utils
16+
17+
p = test_oc_utils.params
18+
1519

1620
class TestOC(unittest.TestCase):
1721
"""
1822
Test functions in neurolib/control/optimal_control/oc.py
1923
"""
2024

21-
@staticmethod
22-
def get_arbitrary_array_finite_values():
23-
"""2x2x10 array filled with arbitrary positive and negative values in range [-5.0, 5.0]."""
24-
return np.array(
25-
[
26-
[
27-
[1.0, 2.0, 0.0, -1.5123, 2.35, 1.0, -1.0, 5.0, 0.0, 0.0],
28-
[-1.0, 3.0, 2.0, 0.5, -0.1, -0.2, 0.2, -0.2, 0.1, 0.5],
29-
],
30-
[
31-
[-0.5, 0.5, 0.5, -0.1, -1.0, 2.05, 3.1, -4.0, -5.0, -0.7],
32-
[1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 2.5],
33-
],
34-
]
35-
)
36-
37-
@staticmethod
38-
def get_arbitrary_array():
39-
"""2x2x10 array filled with arbitrary positive and negative values and +/-np.inf. Besides +/-np.inf all values
40-
fall in the range [-5., 5.].
41-
:return: An array with arbitrary float and np.inf values.
42-
:rtype: np.ndarray of shape 2 x 2 x 10
43-
"""
44-
return np.array(
45-
[
46-
[
47-
[1.0, 2.0, 0.0, -1.5123, 2.35, 1.0, -1.0, 0.0, 0.0, 0.0],
48-
[-1.0, 3.0, 2.0, 0.5, -np.inf, np.inf, 0.2, -0.2, 0.1, 5.0],
49-
],
50-
[
51-
[-0.5, np.inf, 0.5, -0.1, -1.0, 2.5, 3.01, -4.0, -5.0, -0.7],
52-
[1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 2.5],
53-
],
54-
]
55-
)
56-
5725
@staticmethod
5826
def get_deterministic_wilson_cowan_test_setup():
5927
"""Run a Wilson-Cowan model with default parameters for five time steps, use simulated time series to create
@@ -64,23 +32,12 @@ def get_deterministic_wilson_cowan_test_setup():
6432
:rtype: Tuple[neurolib.models.wc.model.WCModel, np.array]
6533
"""
6634
model = WCModel()
67-
dt = model.params["dt"]
68-
duration = 5 * dt
69-
model.params["duration"] = duration
35+
model.params["duration"] = p.TEST_DURATION_6
7036
model.run()
7137

72-
target = np.concatenate(
73-
(
74-
np.concatenate((model.params["exc_init"], model.params["inh_init"]), axis=1)[:, :, np.newaxis],
75-
np.stack((model.exc, model.inh), axis=1),
76-
),
77-
axis=2,
78-
)
79-
38+
target = test_oc_utils.gettarget_1n(model)
8039
target[0, 0, :] = target[0, 0, :] * 2.0
81-
82-
model.params["exc_ext"] = ZeroInput().generate_input(duration=duration + dt, dt=dt)
83-
model.params["inh_ext"] = ZeroInput().generate_input(duration=duration + dt, dt=dt)
40+
test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6)
8441

8542
return model, target
8643

@@ -196,8 +153,27 @@ def test_compute_gradient(self):
196153
for n in range(N):
197154
self.assertTrue(np.all(gadient[n, :, :] == result))
198155

156+
def test_update_input(self):
157+
# Run the test with an instance of an arbitrarily derived class
158+
# check update_input function
159+
print("Tets update_input function.")
160+
161+
model, target = self.get_deterministic_wilson_cowan_test_setup()
162+
model_controlled = OcWc(model, target)
163+
model_controlled.control = np.concatenate(
164+
(p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]), axis=1
165+
)
166+
167+
for iv in model.input_vars:
168+
self.assertTrue((model_controlled.model.params[iv] == 0.0).all())
169+
170+
model_controlled.update_input()
171+
172+
for iv in model.input_vars:
173+
self.assertTrue((model_controlled.model.params[iv] == p.TEST_INPUT_1N_6).all())
174+
199175
def test_step_size(self):
200-
# Run the test with an instance of an arbitrary derived class.
176+
# Run the test with an instance of an arbitrarily derived class.
201177
# This test case is not specific to any step size algorithm or initial step size.
202178

203179
print("Test step size is larger zero.")
@@ -214,7 +190,7 @@ def test_step_size(self):
214190
self.assertTrue(model_controlled.step_size(-model_controlled.compute_gradient()) > 0.0)
215191

216192
def test_step_size_no_step(self):
217-
# Run the test with an instance of an arbitrary derived class.
193+
# Run the test with an instance of an arbitrarily derived class.
218194
# Checks that for a zero-gradient no step is performed (i.e. step-size=0).
219195

220196
print("Test step size is zero if gradient is zero.")
@@ -235,16 +211,18 @@ def test_update_control_with_limit_no_limit(self):
235211

236212
print("Test control update with and without strength limit.")
237213

238-
control = self.get_arbitrary_array_finite_values()
214+
control = np.concatenate((p.TEST_INPUT_2N_6[:, np.newaxis, :], p.TEST_INPUT_2N_6[:, np.newaxis, :]), axis=1)
215+
control[0, 0, -1] = np.inf
216+
control[0, 1 - 1] = -np.inf
239217
step = 1.0
240-
cost_gradient = self.get_arbitrary_array()
241218
(N, dim_in, T) = control.shape
219+
cost_gradient = 4.0 * control
242220

243221
u_max = None
244222
control_limited = update_control_with_limit(N, dim_in, T, control, step, cost_gradient, u_max)
245223
self.assertTrue(np.all(control_limited == control + step * cost_gradient))
246224

247-
u_max = 5.0
225+
u_max = 0.1
248226
control_limited = update_control_with_limit(N, dim_in, T, control, step, cost_gradient, u_max)
249227
self.assertTrue(np.all(np.abs(control_limited) <= u_max))
250228

@@ -253,23 +231,18 @@ def test_limit_control_to_interval(self):
253231

254232
print("Test limit of control interval.")
255233

256-
control = self.get_arbitrary_array_finite_values()
234+
control = np.concatenate((p.TEST_INPUT_2N_6[:, np.newaxis, :], p.TEST_INPUT_2N_6[:, np.newaxis, :]), axis=1)
257235
(N, dim_in, T) = control.shape
258236

259-
control_interval = (0, T)
260-
control_limited = limit_control_to_interval(N, dim_in, T, control, control_interval)
261-
self.assertTrue(np.all(control_limited == control))
262-
263-
control_interval = (3, 7)
264-
control_limited = limit_control_to_interval(N, dim_in, T, control, control_interval)
265-
self.assertTrue(np.all(control_limited[:, :, : control_interval[0]]) == 0.0)
266-
self.assertTrue(np.all(control_limited[:, :, control_interval[1] :]) == 0.0)
267-
self.assertTrue(
268-
np.all(
269-
control_limited[:, :, control_interval[0] : control_interval[1]]
270-
== control[:, :, control_interval[0] : control_interval[1]]
271-
)
272-
)
237+
c_int = (0, T)
238+
control_lim = limit_control_to_interval(N, dim_in, T, control, c_int)
239+
self.assertTrue(np.all(control_lim == control))
240+
241+
c_int = (4, 6)
242+
control_lim = limit_control_to_interval(N, dim_in, T, control, c_int)
243+
self.assertTrue(np.all(control_lim[:, :, : c_int[0]]) == 0.0)
244+
self.assertTrue(np.all(control_lim[:, :, c_int[1] :]) == 0.0)
245+
self.assertTrue(np.all(control_lim[:, :, c_int[0] : c_int[1]] == control[:, :, c_int[0] : c_int[1]]))
273246

274247
def test_convert_interval_none(self):
275248
print("Test convert interval.")
@@ -290,7 +263,7 @@ def test_convert_interval_negative(self):
290263
array_length = 10 # arbitrary
291264
interval = (-6, -2)
292265
interval_converted = convert_interval(interval, array_length)
293-
self.assertTupleEqual(interval_converted, (4, 8))
266+
self.assertTupleEqual(interval_converted, (array_length + interval[0], array_length + interval[1]))
294267

295268
def test_convert_interval_unchanged(self):
296269
print("Test convert interval.")

tests/control/optimal_control/test_oc_fhn.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,44 @@ def test_adjust_input(self):
360360
for input_var0 in model.input_vars:
361361
self.assertTrue(model_controlled.model.params[input_var0].shape == targetinputshape)
362362

363+
# tests if the cost is independent of the integration time step
364+
def test_cost_dt(self):
365+
print("Test cost independent of dt")
366+
model = FHNModel()
367+
model.params["duration"] = p.TEST_DURATION_6
368+
369+
model.params["dt"] = 1e-3
370+
test_input = np.zeros((1, 1 + 100 * (p.TEST_INPUT_1N_6.shape[1] - 1)))
371+
for t in range(p.TEST_INPUT_1N_6.shape[1]):
372+
test_input[0, 100 * t : 100 * t + 100] = p.TEST_INPUT_1N_6[0, t]
373+
374+
test_oc_utils.set_input(model, test_input)
375+
model.run()
376+
target = test_oc_utils.gettarget_1n(model)
377+
test_oc_utils.set_input(model, np.zeros((test_input.shape)))
378+
379+
model_controlled = oc_fhn.OcFhn(model, target)
380+
model_controlled.weights["w_p"] = 1.0
381+
model_controlled.weights["w_2"] = 1.0
382+
cost0 = model_controlled.compute_total_cost()
383+
384+
model.params["dt"] = 1e-4
385+
test_input = np.zeros((1, 1 + 1000 * (p.TEST_INPUT_1N_6.shape[1] - 1)))
386+
for t in range(p.TEST_INPUT_1N_6.shape[1]):
387+
test_input[0, 1000 * t : 1000 * t + 1000] = p.TEST_INPUT_1N_6[0, t]
388+
389+
test_oc_utils.set_input(model, test_input)
390+
model.run()
391+
target = test_oc_utils.gettarget_1n(model)
392+
test_oc_utils.set_input(model, np.zeros((test_input.shape)))
393+
394+
model_controlled = oc_fhn.OcFhn(model, target)
395+
model_controlled.weights["w_p"] = 1.0
396+
model_controlled.weights["w_2"] = 1.0
397+
cost1 = model_controlled.compute_total_cost()
398+
399+
self.assertAlmostEqual(cost0, cost1, 3)
400+
363401

364402
if __name__ == "__main__":
365403
unittest.main()

0 commit comments

Comments
 (0)