@@ -222,32 +222,3 @@ def compute_sync_fourier_cost(self, output):
def compute_kuramoto_cost(self, output):
    """Return the negative time-averaged Kuramoto order parameter of *output*.

    The instantaneous phase is taken from the analytic signal (Hilbert
    transform along axis 2); the order parameter measures phase synchrony
    across axis 1. Negated so that minimizing this cost maximizes synchrony.

    :param output: system output; assumes axis 1 indexes nodes and axis 2
        indexes time — TODO confirm against caller
    :return: scalar cost (more negative = more synchronized)
    """
    instantaneous_phase = jnp.angle(hilbert_jax(output, axis=2))
    phase_vectors = jnp.exp(1j * instantaneous_phase)
    order_parameter = jnp.abs(jnp.mean(phase_vectors, axis=1))
    return -jnp.mean(order_parameter)
225-
226-
def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
    """Compute the optimal control signal for noise averaging method 0.

    Runs gradient descent on ``self.control`` via the optax optimizer stored
    on the instance, optionally logging the cost at regular intervals.

    :param n_max_iterations: maximum number of iterations of gradient descent
    :type n_max_iterations: int
    :param output_every_nth: record (in ``self.cost_history``) and print the
        cost every nth iteration; if None, no intermediate costs are logged
    :type output_every_nth: int, optional
    """
    output = self.get_output(self.control)

    cost = self.compute_total_cost(self.control, output)
    print(f"Cost in iteration 0: {cost}")
    if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
        self.cost_history.append(cost)

    for i in range(1, n_max_iterations + 1):
        self.gradient = self.compute_gradient(self.control)

        updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
        self.control = optax.apply_updates(self.control, updates)

        output = self.get_output(self.control)
        if output_every_nth is not None and i % output_every_nth == 0:
            cost = self.compute_total_cost(self.control, output)
            self.cost_history.append(cost)
            print(f"Cost in iteration {i}: {cost}")

    # Recompute the cost from the final control/output so the reported value
    # reflects the last update, not a stale cost from an earlier logging step
    # (the original printed whatever `cost` last held, which could be the
    # iteration-0 cost when output_every_nth is None).
    cost = self.compute_total_cost(self.control, output)
    print(f"Final cost : {cost}")
0 commit comments