55import copy
66from neurolib .models .jax .wc import WCModel
77from neurolib .models .jax .wc .timeIntegration import timeIntegration_args , timeIntegration_elementwise
8-
9- import logging
8+ from neurolib .optimize .loss_functions import (
9+ kuramoto_loss ,
10+ cross_correlation_loss ,
11+ variance_loss ,
12+ osc_fourier_loss ,
13+ sync_fourier_loss ,
14+ )
1015from neurolib .control .optimal_control .oc import getdefaultweights
1116
12- # TODO: introduce for all models, not just WC
13- wc_default_control_params = ["exc_ext" , "inh_ext" ]
14- wc_default_target_params = ["exc" , "inh" ]
15-
16-
17- def hilbert_jax (signal , axis = - 1 ):
18-
19- n = signal .shape [axis ]
20- h = jnp .zeros (n )
21- h = h .at [0 ].set (1 )
22-
23- if n % 2 == 0 :
24- h = h .at [1 : n // 2 ].set (2 )
25- h = h .at [n // 2 ].set (1 )
26- else :
27- h = h .at [1 : (n + 1 ) // 2 ].set (2 )
28-
29- h = jnp .expand_dims (h , tuple (i for i in range (signal .ndim ) if i != axis ))
30- h = jnp .broadcast_to (h , signal .shape )
31-
32- fft_signal = jnp .fft .fft (signal , axis = axis )
33- analytic_fft = fft_signal * h
34-
35- analytic_signal = jnp .fft .ifft (analytic_fft )
36- return analytic_signal
37-
3817
3918class Optimize :
4019 def __init__ (
@@ -43,8 +22,8 @@ def __init__(
4322 loss_function ,
4423 param_names ,
4524 target_param_names ,
46- target = None ,
4725 init_params = None ,
26+ regularization_function = lambda _ : 0.0 ,
4827 optimizer = optax .adabelief (1e-3 ),
4928 ):
5029 assert isinstance (param_names , (list , tuple )) and len (param_names ) > 0
@@ -54,7 +33,7 @@ def __init__(
5433
5534 self .model = copy .deepcopy (model )
5635 self .loss_function = loss_function
57- self .target = target
36+ self .regularization_function = regularization_function
5837 self .optimizer = optimizer
5938 self .param_names = param_names
6039 self .target_param_names = target_param_names
@@ -70,7 +49,7 @@ def __init__(
7049 self .params = dict (zip (param_names , [self .args [p ] for p in param_names ]))
7150 self .opt_state = self .optimizer .init (self .params )
7251
73- compute_loss = lambda params : self .loss_function (params , self .get_output (params ))
52+ compute_loss = lambda params : self .loss_function (self .get_output (params )) + self . regularization_function ( params )
7453 self .compute_loss = jax .jit (compute_loss )
7554 self .compute_gradient = jax .jit (jax .grad (self .compute_loss ))
7655
@@ -93,15 +72,7 @@ def get_output(self, params):
9372 simulation_results = self .simulate (params )
9473 return jnp .stack ([simulation_results [tp ][:, self .startind :] for tp in self .target_param_names ])
9574
96- def get_loss (self ):
97- @jax .jit
98- def loss (params ):
99- output = self .get_output (params )
100- return self .loss_function (params , output )
101-
102- return loss
103-
104- def optimize_deterministic (self , n_max_iterations , output_every_nth = None ):
75+ def optimize (self , n_max_iterations , output_every_nth = None ):
10576 loss = self .compute_loss (self .control )
10677 print (f"loss in iteration 0: %s" % (loss ))
10778 if len (self .cost_history ) == 0 : # add only if control model has not yet been optimized
@@ -121,105 +92,87 @@ def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
12192 print (f"Final loss : %s" % (loss ))
12293
12394
124- class OcWc (Optimize ):
95+ class Oc (Optimize ):
96+ """
97+ Convenience class for optimal control. The cost functional is constructed as a weighted sum of accuracy and control strength costs. Requires optimization parameters to be of shape (N, T).
98+ """
99+
100+ supported_cost_parameters = [
101+ "w_p" ,
102+ "w_cc" ,
103+ "w_var" ,
104+ "w_f_osc" ,
105+ "w_f_sync" ,
106+ "w_ko" ,
107+ "w_2" ,
108+ "w_1D" ,
109+ ]
110+
125111 def __init__ (
126112 self ,
127113 model ,
128- target = None ,
114+ target_timeseries = None ,
115+ target_frequency = None ,
129116 optimizer = optax .adabelief (1e-3 ),
130- control_param_names = wc_default_control_params ,
131- target_param_names = wc_default_target_params ,
117+ control_param_names = ["exc_ext" , "inh_ext" ],
118+ target_param_names = ["exc" , "inh" ],
119+ weights = None ,
132120 ):
133121 super ().__init__ (
134122 model ,
135- self .compute_total_cost ,
123+ self .accuracy_cost ,
136124 control_param_names ,
137125 target_param_names ,
138- target = target ,
139126 init_params = None ,
140127 optimizer = optimizer ,
128+ regularization_function = self .control_strength_cost ,
141129 )
130+ self .target_timeseries = target_timeseries
131+ self .target_frequency = target_frequency
142132 self .control = self .params
143- self .weights = getdefaultweights ()
133+ if weights is None :
134+ self .weights = getdefaultweights ()
144135
145- def compute_total_cost (self , control , output ):
136+ def accuracy_cost (self , output ):
146137 """
147- Compute the total cost as the sum of accuracy cost and control strength cost.
148-
149- Parameters:
150- control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
151- output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
152-
153- Returns:
154- float: The total cost.
138+ Args:
139+ output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
155140 """
156- accuracy_cost = self .accuracy_cost (output )
157- control_arr = jnp .array (list (control .values ()))
158- control_strength_cost = self .control_strength_cost (control_arr )
159- return accuracy_cost + control_strength_cost
160-
161- # TODO: move cost functions outside
162- def accuracy_cost (self , output ):
163141 accuracy_cost = 0.0
164142 if self .weights ["w_p" ] != 0.0 :
165- accuracy_cost += self .weights ["w_p" ] * 0.5 * self .model . params . dt * jnp . sum (( output - self . target ) ** 2 )
143+ accuracy_cost += self .weights ["w_p" ] * self .precision_cost ( output )
166144 if self .weights ["w_cc" ] != 0.0 :
167- accuracy_cost += self .weights ["w_cc" ] * self . compute_cc_cost (output )
145+ accuracy_cost += self .weights ["w_cc" ] * cross_correlation_loss (output , self . model . params . dt )
168146 if self .weights ["w_var" ] != 0.0 :
169- accuracy_cost += self .weights ["w_var" ] * self . compute_var_cost (output )
147+ accuracy_cost += self .weights ["w_var" ] * variance_loss (output )
170148 if self .weights ["w_f_osc" ] != 0.0 :
171- accuracy_cost += self .weights ["w_f_osc" ] * self .compute_osc_fourier_cost (output )
149+ accuracy_cost += self .weights ["w_f_osc" ] * osc_fourier_loss (
150+ output , self .target_frequency , self .model .params .dt
151+ )
172152 if self .weights ["w_f_sync" ] != 0.0 :
173- accuracy_cost += self .weights ["w_f_sync" ] * self .compute_sync_fourier_cost (output )
153+ accuracy_cost += self .weights ["w_f_sync" ] * sync_fourier_loss (
154+ output , self .target_frequency , self .model .params .dt
155+ )
174156 if self .weights ["w_ko" ] != 0.0 :
175- accuracy_cost += self .weights ["w_ko" ] * self . compute_kuramoto_cost (output )
157+ accuracy_cost += self .weights ["w_ko" ] * kuramoto_loss (output )
176158 return accuracy_cost
177159
160+ def precision_cost (self , output ):
161+ return 0.5 * self .model .params .dt * jnp .sum ((output - self .target_timeseries ) ** 2 )
162+
178163 def control_strength_cost (self , control ):
164+ """
165+ Args:
166+ control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
167+ """
168+ control_arr = jnp .array (list (control .values ()))
179169 control_strength_cost = 0.0
180170 if self .weights ["w_2" ] != 0.0 :
181- control_strength_cost += self .weights ["w_2" ] * 0.5 * self .model .params .dt * jnp .sum (control ** 2 )
171+ control_strength_cost += self .weights ["w_2" ] * 0.5 * self .model .params .dt * jnp .sum (control_arr ** 2 )
182172 if self .weights ["w_1D" ] != 0.0 :
183- control_strength_cost += self .weights ["w_1D" ] * self .compute_ds_cost (control )
173+ control_strength_cost += self .weights ["w_1D" ] * self .compute_ds_cost (control_arr )
184174 return control_strength_cost
185175
186176 def compute_ds_cost (self , control ):
187177 eps = 1e-6 # avoid grad(sqrt(0.0))
188178 return jnp .sum (jnp .sqrt (jnp .sum (control ** 2 , axis = 2 ) * self .model .params .dt + eps ))
189-
190- def compute_cc_cost (self , output ):
191- xmean = jnp .mean (output , axis = 2 , keepdims = True )
192- xstd = jnp .std (output , axis = 2 , keepdims = True )
193-
194- xvec = (output - xmean ) / xstd
195-
196- costmat = jnp .einsum ("vnt,vkt->vnkt" , xvec , xvec )
197- diag = jnp .einsum ("vnt,vnt->vt" , xvec , xvec )
198- cost = jnp .sum (jnp .sum (costmat , axis = (1 , 2 )) - diag ) * self .model .params .dt / 2.0
199- cost *= - 2.0 / (self .model .params .N * (self .model .params .N - 1 ) * self .T * self .model .params .dt )
200- return cost
201-
202- def compute_var_cost (self , output ):
203- return jnp .var (output , axis = (0 , 1 )).mean ()
204-
205- def get_fourier_component (self , data , target_period ):
206- fourier_series = jnp .abs (jnp .fft .fft (data )[: len (data ) // 2 ])
207- freqs = jnp .fft .fftfreq (data .size , d = self .model .params .dt )[: len (data ) // 2 ]
208- return fourier_series [jnp .argmin (jnp .abs (freqs - 1.0 / target_period ))]
209-
210- def compute_osc_fourier_cost (self , output ):
211- cost = 0.0
212- for n in range (output .shape [1 ]):
213- for v in range (output .shape [0 ]):
214- cost -= self .get_fourier_component (output [v , n ], self .target ) ** 2
215- return cost / (output .shape [2 ] * self .model .params .dt ) ** 2
216-
217- def compute_sync_fourier_cost (self , output ):
218- cost = 0.0
219- for v in range (output .shape [0 ]):
220- cost -= self .get_fourier_component (jnp .sum (output [v ], axis = 0 ), self .target ) ** 2
221- return cost / (output .shape [2 ] * self .model .params .dt ) ** 2
222-
223- def compute_kuramoto_cost (self , output ):
224- phase = jnp .angle (hilbert_jax (output , axis = 2 ))
225- return - jnp .mean (jnp .abs (jnp .mean (jnp .exp (complex (0 , 1 ) * phase ), axis = 1 )))
0 commit comments