1212from neurolib .models .wc import WCModel
1313from neurolib .utils .stimulus import ZeroInput
1414
15+ import test_oc_utils as test_oc_utils
16+
17+ p = test_oc_utils .params
18+
1519
1620class TestOC (unittest .TestCase ):
1721 """
1822 Test functions in neurolib/control/optimal_control/oc.py
1923 """
2024
21- @staticmethod
22- def get_arbitrary_array_finite_values ():
23- """2x2x10 array filled with arbitrary positive and negative values in range [-5.0, 5.0]."""
24- return np .array (
25- [
26- [
27- [1.0 , 2.0 , 0.0 , - 1.5123 , 2.35 , 1.0 , - 1.0 , 5.0 , 0.0 , 0.0 ],
28- [- 1.0 , 3.0 , 2.0 , 0.5 , - 0.1 , - 0.2 , 0.2 , - 0.2 , 0.1 , 0.5 ],
29- ],
30- [
31- [- 0.5 , 0.5 , 0.5 , - 0.1 , - 1.0 , 2.05 , 3.1 , - 4.0 , - 5.0 , - 0.7 ],
32- [1.0 , 0.0 , 0.0 , 0.0 , - 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 2.5 ],
33- ],
34- ]
35- )
36-
37- @staticmethod
38- def get_arbitrary_array ():
39- """2x2x10 array filled with arbitrary positive and negative values and +/-np.inf. Besides +/-np.inf all values
40- fall in the range [-5., 5.].
41- :return: An array with arbitrary float and np.inf values.
42- :rtype: np.ndarray of shape 2 x 2 x 10
43- """
44- return np .array (
45- [
46- [
47- [1.0 , 2.0 , 0.0 , - 1.5123 , 2.35 , 1.0 , - 1.0 , 0.0 , 0.0 , 0.0 ],
48- [- 1.0 , 3.0 , 2.0 , 0.5 , - np .inf , np .inf , 0.2 , - 0.2 , 0.1 , 5.0 ],
49- ],
50- [
51- [- 0.5 , np .inf , 0.5 , - 0.1 , - 1.0 , 2.5 , 3.01 , - 4.0 , - 5.0 , - 0.7 ],
52- [1.0 , 0.0 , 0.0 , 0.0 , - 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 2.5 ],
53- ],
54- ]
55- )
56-
5725 @staticmethod
5826 def get_deterministic_wilson_cowan_test_setup ():
5927 """Run a Wilson-Cowan model with default parameters for five time steps, use simulated time series to create
@@ -64,23 +32,12 @@ def get_deterministic_wilson_cowan_test_setup():
6432 :rtype: Tuple[neurolib.models.wc.model.WCModel, np.array]
6533 """
6634 model = WCModel ()
67- dt = model .params ["dt" ]
68- duration = 5 * dt
69- model .params ["duration" ] = duration
35+ model .params ["duration" ] = p .TEST_DURATION_6
7036 model .run ()
7137
72- target = np .concatenate (
73- (
74- np .concatenate ((model .params ["exc_init" ], model .params ["inh_init" ]), axis = 1 )[:, :, np .newaxis ],
75- np .stack ((model .exc , model .inh ), axis = 1 ),
76- ),
77- axis = 2 ,
78- )
79-
38+ target = test_oc_utils .gettarget_1n (model )
8039 target [0 , 0 , :] = target [0 , 0 , :] * 2.0
81-
82- model .params ["exc_ext" ] = ZeroInput ().generate_input (duration = duration + dt , dt = dt )
83- model .params ["inh_ext" ] = ZeroInput ().generate_input (duration = duration + dt , dt = dt )
40+ test_oc_utils .set_input (model , p .ZERO_INPUT_1N_6 )
8441
8542 return model , target
8643
@@ -196,8 +153,27 @@ def test_compute_gradient(self):
196153 for n in range (N ):
197154 self .assertTrue (np .all (gadient [n , :, :] == result ))
198155
156+ def test_update_input (self ):
157+ # Run the test with an instance of an arbitrarily derived class
158+ # check update_input function
159+ print ("Tets update_input function." )
160+
161+ model , target = self .get_deterministic_wilson_cowan_test_setup ()
162+ model_controlled = OcWc (model , target )
163+ model_controlled .control = np .concatenate (
164+ (p .TEST_INPUT_1N_6 [:, np .newaxis , :], p .TEST_INPUT_1N_6 [:, np .newaxis , :]), axis = 1
165+ )
166+
167+ for iv in model .input_vars :
168+ self .assertTrue ((model_controlled .model .params [iv ] == 0.0 ).all ())
169+
170+ model_controlled .update_input ()
171+
172+ for iv in model .input_vars :
173+ self .assertTrue ((model_controlled .model .params [iv ] == p .TEST_INPUT_1N_6 ).all ())
174+
199175 def test_step_size (self ):
200- # Run the test with an instance of an arbitrary derived class.
176+ # Run the test with an instance of an arbitrarily derived class.
201177 # This test case is not specific to any step size algorithm or initial step size.
202178
203179 print ("Test step size is larger zero." )
@@ -214,7 +190,7 @@ def test_step_size(self):
214190 self .assertTrue (model_controlled .step_size (- model_controlled .compute_gradient ()) > 0.0 )
215191
216192 def test_step_size_no_step (self ):
217- # Run the test with an instance of an arbitrary derived class.
193+ # Run the test with an instance of an arbitrarily derived class.
218194 # Checks that for a zero-gradient no step is performed (i.e. step-size=0).
219195
220196 print ("Test step size is zero if gradient is zero." )
@@ -235,16 +211,18 @@ def test_update_control_with_limit_no_limit(self):
235211
236212 print ("Test control update with and without strength limit." )
237213
238- control = self .get_arbitrary_array_finite_values ()
214+ control = np .concatenate ((p .TEST_INPUT_2N_6 [:, np .newaxis , :], p .TEST_INPUT_2N_6 [:, np .newaxis , :]), axis = 1 )
215+ control [0 , 0 , - 1 ] = np .inf
216+ control [0 , 1 - 1 ] = - np .inf
239217 step = 1.0
240- cost_gradient = self .get_arbitrary_array ()
241218 (N , dim_in , T ) = control .shape
219+ cost_gradient = 4.0 * control
242220
243221 u_max = None
244222 control_limited = update_control_with_limit (N , dim_in , T , control , step , cost_gradient , u_max )
245223 self .assertTrue (np .all (control_limited == control + step * cost_gradient ))
246224
247- u_max = 5.0
225+ u_max = 0.1
248226 control_limited = update_control_with_limit (N , dim_in , T , control , step , cost_gradient , u_max )
249227 self .assertTrue (np .all (np .abs (control_limited ) <= u_max ))
250228
@@ -253,23 +231,18 @@ def test_limit_control_to_interval(self):
253231
254232 print ("Test limit of control interval." )
255233
256- control = self . get_arbitrary_array_finite_values ( )
234+ control = np . concatenate (( p . TEST_INPUT_2N_6 [:, np . newaxis , :], p . TEST_INPUT_2N_6 [:, np . newaxis , :]), axis = 1 )
257235 (N , dim_in , T ) = control .shape
258236
259- control_interval = (0 , T )
260- control_limited = limit_control_to_interval (N , dim_in , T , control , control_interval )
261- self .assertTrue (np .all (control_limited == control ))
262-
263- control_interval = (3 , 7 )
264- control_limited = limit_control_to_interval (N , dim_in , T , control , control_interval )
265- self .assertTrue (np .all (control_limited [:, :, : control_interval [0 ]]) == 0.0 )
266- self .assertTrue (np .all (control_limited [:, :, control_interval [1 ] :]) == 0.0 )
267- self .assertTrue (
268- np .all (
269- control_limited [:, :, control_interval [0 ] : control_interval [1 ]]
270- == control [:, :, control_interval [0 ] : control_interval [1 ]]
271- )
272- )
237+ c_int = (0 , T )
238+ control_lim = limit_control_to_interval (N , dim_in , T , control , c_int )
239+ self .assertTrue (np .all (control_lim == control ))
240+
241+ c_int = (4 , 6 )
242+ control_lim = limit_control_to_interval (N , dim_in , T , control , c_int )
243+ self .assertTrue (np .all (control_lim [:, :, : c_int [0 ]]) == 0.0 )
244+ self .assertTrue (np .all (control_lim [:, :, c_int [1 ] :]) == 0.0 )
245+ self .assertTrue (np .all (control_lim [:, :, c_int [0 ] : c_int [1 ]] == control [:, :, c_int [0 ] : c_int [1 ]]))
273246
274247 def test_convert_interval_none (self ):
275248 print ("Test convert interval." )
@@ -290,7 +263,7 @@ def test_convert_interval_negative(self):
290263 array_length = 10 # arbitrary
291264 interval = (- 6 , - 2 )
292265 interval_converted = convert_interval (interval , array_length )
293- self .assertTupleEqual (interval_converted , (4 , 8 ))
266+ self .assertTupleEqual (interval_converted , (array_length + interval [ 0 ], array_length + interval [ 1 ] ))
294267
295268 def test_convert_interval_unchanged (self ):
296269 print ("Test convert interval." )
0 commit comments