@@ -42,14 +42,17 @@ def __init__(
4242 self .args = dict (zip (self .model .args_names , args_values ))
4343
4444 self .T = len (self .args ["t" ])
45- self .startind = self .model .getMaxDelay ()
45+ self .startind = self .model .getMaxDelay () + 1
4646 if init_params is not None :
4747 self .params = init_params
4848 else :
4949 self .params = dict (zip (param_names , [self .args [p ] for p in param_names ]))
5050 self .opt_state = self .optimizer .init (self .params )
5151
52- compute_loss = lambda params : self .loss_function (self .get_output (params )) + self .regularization_function (params )
52+ # TODO: instead apply individually to each param
53+ compute_loss = lambda params : self .loss_function (
54+ jnp .stack (list (self .get_output (params ).values ()))
55+ ) + self .regularization_function (params )
5356 self .compute_loss = jax .jit (compute_loss )
5457 self .compute_gradient = jax .jit (jax .grad (self .compute_loss ))
5558
@@ -70,25 +73,25 @@ def simulate(self, params):
7073
7174 def get_output (self , params ):
7275 simulation_results = self .simulate (params )
73- return jnp . stack ([ simulation_results [tp ][:, self .startind :] for tp in self .target_param_names ])
76+ return { tp : simulation_results [tp ][:, self .startind :] for tp in self .target_param_names }
7477
7578 def optimize (self , n_max_iterations , output_every_nth = None ):
76- loss = self .compute_loss (self .control )
79+ loss = self .compute_loss (self .params )
7780 print (f"loss in iteration 0: %s" % (loss ))
78- if len (self .cost_history ) == 0 : # add only if control model has not yet been optimized
81+ if len (self .cost_history ) == 0 : # add only if params have not yet been optimized
7982 self .cost_history .append (loss )
8083
8184 for i in range (1 , n_max_iterations + 1 ):
82- self .gradient = self .compute_gradient (self .control )
85+ self .gradient = self .compute_gradient (self .params )
8386 updates , self .opt_state = self .optimizer .update (self .gradient , self .opt_state )
84- self .control = optax .apply_updates (self .control , updates )
87+ self .params = optax .apply_updates (self .params , updates )
8588
8689 if output_every_nth is not None and i % output_every_nth == 0 :
87- loss = self .compute_loss (self .control )
90+ loss = self .compute_loss (self .params )
8891 self .cost_history .append (loss )
8992 print (f"loss in iteration %s: %s" % (i , loss ))
9093
91- loss = self .compute_loss (self .control )
94+ loss = self .compute_loss (self .params )
9295 print (f"Final loss : %s" % (loss ))
9396
9497
@@ -176,3 +179,7 @@ def control_strength_cost(self, control):
176179 def compute_ds_cost (self , control ):
177180 eps = 1e-6 # avoid grad(sqrt(0.0))
178181 return jnp .sum (jnp .sqrt (jnp .sum (control ** 2 , axis = 2 ) * self .model .params .dt + eps ))
182+
183+ def optimize (self , * args , ** kwargs ):
184+ super ().optimize (* args , ** kwargs )
185+ self .control = self .params
0 commit comments