55import copy
66from neurolib .models .jax .wc import WCModel
77from neurolib .models .jax .wc .timeIntegration import timeIntegration_args , timeIntegration_elementwise
8- from neurolib .optimize .autodiff .wc_optimizer import args_names
98
109import logging
1110from neurolib .control .optimal_control .oc import getdefaultweights
1211
12+ # TODO: introduce for all models, not just WC
1313wc_default_control_params = ["exc_ext" , "inh_ext" ]
1414wc_default_target_params = ["exc" , "inh" ]
1515
@@ -35,78 +35,129 @@ def hilbert_jax(signal, axis=-1):
3535 return analytic_signal
3636
3737
38- class OcWc :
38+ class Optimize :
3939 def __init__ (
4040 self ,
4141 model ,
42+ loss_function ,
43+ param_names ,
44+ target_param_names ,
4245 target = None ,
43- optimizer = optax .adam (1e-3 ),
44- control_params = wc_default_control_params ,
45- target_params = wc_default_target_params ,
46+ init_params = None ,
47+ optimizer = optax .adabelief (1e-3 ),
4648 ):
47- assert isinstance (control_params , (list , tuple )) and len (control_params ) > 0
48- assert isinstance (target_params , (list , tuple )) and len (target_params ) > 0
49- assert all ([cp in wc_default_control_params for cp in control_params ])
50- assert all ([tp in wc_default_target_params for tp in target_params ])
49+ assert isinstance (param_names , (list , tuple )) and len (param_names ) > 0
50+ assert isinstance (target_param_names , (list , tuple )) and len (target_param_names ) > 0
51+ assert all ([p in model . args_names for p in param_names ])
52+ assert all ([tp in model . output_vars for tp in target_param_names ])
5153
5254 self .model = copy .deepcopy (model )
55+ self .loss_function = loss_function
5356 self .target = target
5457 self .optimizer = optimizer
55- self .control_params = control_params
56- self .target_params = target_params
57-
58- self .weights = getdefaultweights ()
58+ self .param_names = param_names
59+ self .target_param_names = target_param_names
5960
6061 args_values = timeIntegration_args (self .model .params )
61- self .args = dict (zip (args_names , args_values ))
62+ self .args = dict (zip (self . model . args_names , args_values ))
6263
63- self .loss = self .get_loss ()
64- self .compute_gradient = jax .jit (jax .grad (self .loss ))
6564 self .T = len (self .args ["t" ])
6665 self .startind = self .model .getMaxDelay ()
67- self .control = jnp .zeros ((len (control_params ), self .model .params .N , self .T ), dtype = float )
68- self .opt_state = self .optimizer .init (self .control )
66+ if init_params is not None :
67+ self .params = init_params
68+ else :
69+ self .params = dict (zip (param_names , [self .args [p ] for p in param_names ]))
70+ self .opt_state = self .optimizer .init (self .params )
71+
72+ compute_loss = lambda params : self .loss_function (params , self .get_output (params ))
73+ self .compute_loss = jax .jit (compute_loss )
74+ self .compute_gradient = jax .jit (jax .grad (self .compute_loss ))
6975
7076 self .cost_history = []
7177
72- def simulate (self , control ):
78+ # TODO: allow arbitrary model, not just WC
79+ def simulate (self , params ):
7380 args_local = self .args .copy ()
74- args_local .update (dict (zip (self .control_params , [c for c in control ])))
75- return timeIntegration_elementwise (** args_local )
76-
77- def get_output (self , control ):
78- t , exc , inh , exc_ou , inh_ou = self .simulate (control )
79- if self .target_params == ["exc" , "inh" ]:
80- output = jnp .stack ((exc , inh ), axis = 0 )
81- elif self .target_params == ["exc" ]:
82- output = exc [None , ...]
83- elif self .target_params == ["inh" ]:
84- output = inh [None , ...]
85- return output [:, :, self .startind :]
81+ args_local .update (params )
82+ t , exc , inh , exc_ou , inh_ou = timeIntegration_elementwise (** args_local )
83+ return {
84+ "t" : t ,
85+ "exc" : exc ,
86+ "inh" : inh ,
87+ "exc_ou" : exc_ou ,
88+ "inh_ou" : inh_ou ,
89+ }
90+
91+ def get_output (self , params ):
92+ simulation_results = self .simulate (params )
93+ return jnp .stack ([simulation_results [tp ][:, self .startind :] for tp in self .target_param_names ])
8694
8795 def get_loss (self ):
8896 @jax .jit
89- def loss (control ):
90- output = self .get_output (control )
91- return self .compute_total_cost ( control , output )
97+ def loss (params ):
98+ output = self .get_output (params )
99+ return self .loss_function ( params , output )
92100
93101 return loss
94102
103+ def optimize_deterministic (self , n_max_iterations , output_every_nth = None ):
104+ loss = self .compute_loss (self .control )
105+ print (f"loss in iteration 0: %s" % (loss ))
106+ if len (self .cost_history ) == 0 : # add only if control model has not yet been optimized
107+ self .cost_history .append (loss )
108+
109+ for i in range (1 , n_max_iterations + 1 ):
110+ self .gradient = self .compute_gradient (self .control )
111+ updates , self .opt_state = self .optimizer .update (self .gradient , self .opt_state )
112+ self .control = optax .apply_updates (self .control , updates )
113+
114+ if output_every_nth is not None and i % output_every_nth == 0 :
115+ loss = self .compute_loss (self .control )
116+ self .cost_history .append (loss )
117+ print (f"loss in iteration %s: %s" % (i , loss ))
118+
119+ loss = self .compute_loss (self .control )
120+ print (f"Final loss : %s" % (loss ))
121+
122+
123+ class OcWc (Optimize ):
124+ def __init__ (
125+ self ,
126+ model ,
127+ target = None ,
128+ optimizer = optax .adabelief (1e-3 ),
129+ control_param_names = wc_default_control_params ,
130+ target_param_names = wc_default_target_params ,
131+ ):
132+ super ().__init__ (
133+ model ,
134+ self .compute_total_cost ,
135+ control_param_names ,
136+ target_param_names ,
137+ target = target ,
138+ init_params = None ,
139+ optimizer = optimizer ,
140+ )
141+ self .control = self .params
142+ self .weights = getdefaultweights ()
143+
95144 def compute_total_cost (self , control , output ):
96145 """
97146 Compute the total cost as the sum of accuracy cost and control strength cost.
98147
99148 Parameters:
100- control (jax.numpy.ndarray): Control input array of shape ((len(control_params)), N, T).
101- output (jax.numpy.ndarray): Simulation output of shape ((len(target_params )), N, T).
149+ control (dict[str, jax.numpy.ndarray] ): Dictionary of control inputs, where each entry has shape (N, T).
150+ output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names )), N, T).
102151
103152 Returns:
104153 float: The total cost.
105154 """
106155 accuracy_cost = self .accuracy_cost (output )
107- control_strength_cost = self .control_strength_cost (control )
156+ control_arr = jnp .array (list (control .values ()))
157+ control_strength_cost = self .control_strength_cost (control_arr )
108158 return accuracy_cost + control_strength_cost
109159
160+ # TODO: move cost functions outside
110161 def accuracy_cost (self , output ):
111162 accuracy_cost = 0.0
112163 if self .weights ["w_p" ] != 0.0 :
0 commit comments