22import numpy as np
33
44from neurolib .control .optimal_control .oc import OC
5- from neurolib .models .aln .timeIntegration import compute_hx , compute_hx_nw , Duh , Dxdoth , compute_hx_de , compute_hx_di
5+ from neurolib .models .aln .timeIntegration import (
6+ compute_hx ,
7+ compute_hx_nw ,
8+ Duh ,
9+ Dxdoth ,
10+ compute_hx_de ,
11+ compute_hx_di ,
12+ )
613
714
815class OcAln (OC ):
@@ -21,6 +28,7 @@ def __init__(
2128 weights = None ,
2229 print_array = [],
2330 cost_interval = (None , None ),
31+ control_interval = (None , None ),
2432 cost_matrix = None ,
2533 control_matrix = None ,
2634 M = 1 ,
@@ -34,6 +42,7 @@ def __init__(
3442 print_array = print_array ,
3543 cost_interval = cost_interval ,
3644 cost_matrix = cost_matrix ,
45+ control_interval = control_interval ,
3746 control_matrix = control_matrix ,
3847 M = M ,
3948 M_validation = M_validation ,
@@ -197,7 +206,9 @@ def compute_hx_list(self):
197206 hx_de = self .compute_hx_de ()
198207 hx_di = self .compute_hx_di ()
199208
200- return numba .typed .List ([hx , hx_de , hx_di ]), numba .typed .List ([0 , self .ndt_de , self .ndt_di ])
209+ return numba .typed .List ([hx , hx_de , hx_di ]), numba .typed .List (
210+ [0 , self .ndt_de , self .ndt_di ]
211+ )
201212
202213 def compute_hx (self ):
203214 """Jacobians of ALNModel wrt. the 'e'- and 'i'-variable for each time step.
@@ -317,7 +328,9 @@ def get_fullstate(self):
317328 if t <= T - 2 :
318329 self .model .params [iv ] = control [:, iv_ind , t : t + 2 ]
319330 elif t == T - 1 :
320- self .model .params [iv ] = np .concatenate ((control [:, iv_ind , t :], np .zeros ((N , 1 ))), axis = 1 )
331+ self .model .params [iv ] = np .concatenate (
332+ (control [:, iv_ind , t :], np .zeros ((N , 1 ))), axis = 1
333+ )
321334 else :
322335 self .model .params [iv ] = 0.0
323336 self .model .run ()
@@ -349,11 +362,19 @@ def setasinit(self, fullstate, t):
349362
350363 for n in range (N ):
351364 for v in range (V ):
352- if "rates" in self .model .init_vars [v ] or "IA" in self .model .init_vars [v ]:
365+ if (
366+ "rates" in self .model .init_vars [v ]
367+ or "IA" in self .model .init_vars [v ]
368+ ):
353369 if t >= T :
354- self .model .params [self .model .init_vars [v ]] = fullstate [:, v , t - T : t + 1 ]
370+ self .model .params [self .model .init_vars [v ]] = fullstate [
371+ :, v , t - T : t + 1
372+ ]
355373 else :
356- init = np .concatenate ((fullstate [:, v , - T + t + 1 :], fullstate [:, v , : t + 1 ]), axis = 1 )
374+ init = np .concatenate (
375+ (fullstate [:, v , - T + t + 1 :], fullstate [:, v , : t + 1 ]),
376+ axis = 1 ,
377+ )
357378 self .model .params [self .model .init_vars [v ]] = init
358379 else :
359380 self .model .params [self .model .init_vars [v ]] = fullstate [:, v , t ]
@@ -371,8 +392,13 @@ def getinitstate(self):
371392
372393 for n in range (N ):
373394 for v in range (V ):
374- if "rates" in self .model .init_vars [v ] or "IA" in self .model .init_vars [v ]:
375- initstate [n , v , :] = self .model .params [self .model .init_vars [v ]][n , - T :]
395+ if (
396+ "rates" in self .model .init_vars [v ]
397+ or "IA" in self .model .init_vars [v ]
398+ ):
399+ initstate [n , v , :] = self .model .params [self .model .init_vars [v ]][
400+ n , - T :
401+ ]
376402
377403 else :
378404 initstate [n , v , :] = self .model .params [self .model .init_vars [v ]][n ]
@@ -389,7 +415,10 @@ def getfinalstate(self):
389415 state = np .zeros ((N , V ))
390416 for n in range (N ):
391417 for v in range (V ):
392- if "rates" in self .model .state_vars [v ] or "IA" in self .model .state_vars [v ]:
418+ if (
419+ "rates" in self .model .state_vars [v ]
420+ or "IA" in self .model .state_vars [v ]
421+ ):
393422 state [n , v ] = self .model .state [self .model .state_vars [v ]][n , - 1 ]
394423
395424 else :
@@ -408,7 +437,10 @@ def setinitstate(self, state):
408437
409438 for n in range (N ):
410439 for v in range (V ):
411- if "rates" in self .model .init_vars [v ] or "IA" in self .model .init_vars [v ]:
440+ if (
441+ "rates" in self .model .init_vars [v ]
442+ or "IA" in self .model .init_vars [v ]
443+ ):
412444 self .model .params [self .model .init_vars [v ]] = state [:, v , - T :]
413445 else :
414446 self .model .params [self .model .init_vars [v ]] = state [:, v , - 1 ]
0 commit comments