Skip to content

Commit 89af353

Browse files
committed
cost functionals: fourier, kuramoto
1 parent b7cd619 commit 89af353

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

neurolib/control/optimal_control/oc_jax.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@
1313
wc_default_control_params = ["exc_ext", "inh_ext"]
1414
wc_default_target_params = ["exc", "inh"]
1515

16+
def hilbert_jax(signal, axis=-1):
17+
18+
n = signal.shape[axis]
19+
h = jnp.zeros(n)
20+
h = h.at[0].set(1)
21+
22+
if n % 2 == 0:
23+
h = h.at[1:n//2].set(2)
24+
h = h.at[n//2].set(1)
25+
else:
26+
h = h.at[1:(n+1)//2].set(2)
27+
28+
h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis))
29+
h = jnp.broadcast_to(h, signal.shape)
30+
31+
fft_signal = jnp.fft.fft(signal, axis=axis)
32+
analytic_fft = fft_signal * h
33+
34+
analytic_signal = jnp.fft.ifft(analytic_fft)
35+
return analytic_signal
36+
1637

1738
class OcWc:
1839
def __init__(
@@ -96,6 +117,10 @@ def accuracy_cost(self, output):
96117
accuracy_cost += self.weights["w_var"] * self.compute_var_cost(output)
97118
if self.weights["w_f_osc"] != 0.0:
98119
accuracy_cost += self.weights["w_f_osc"] * self.compute_osc_fourier_cost(output)
120+
if self.weights["w_f_sync"] != 0.0:
121+
accuracy_cost += self.weights["w_f_sync"] * self.compute_sync_fourier_cost(output)
122+
if self.weights["w_kuramoto"] != 0.0:
123+
accuracy_cost += self.weights["w_kuramoto"] * self.compute_kuramoto_cost(output)
99124
return accuracy_cost
100125

101126
def control_strength_cost(self, control):
@@ -124,8 +149,29 @@ def compute_cc_cost(self, output):
124149

125150
def compute_var_cost(self, output):
126151
return jnp.var(output, axis=(0, 1)).mean()
152+
153+
def get_fourier_component(self, data, target_period):
154+
fourier_series = jnp.abs(jnp.fft.fft(data)[:len(data)//2])
155+
freqs = jnp.fft.fftfreq(data.size,d=self.model.params.dt)[:len(data)//2]
156+
return fourier_series[jnp.argmin(jnp.abs(freqs - 1./target_period))]
127157

128158
def compute_osc_fourier_cost(self, output):
159+
cost = 0.0
160+
for n in range(output.shape[1]):
161+
for v in range(output.shape[0]):
162+
cost -= self.get_fourier_component(output[v, n], self.target)**2
163+
return cost / (output.shape[2] * self.model.params.dt)**2
164+
165+
def compute_sync_fourier_cost(self, output):
166+
cost = 0.0
167+
for v in range(output.shape[0]):
168+
cost -= self.get_fourier_component(jnp.sum(output[v], axis=0), self.target)**2
169+
return cost / (output.shape[2] * self.model.params.dt)**2
170+
171+
def compute_kuramoto_cost(self, output):
172+
phase = jnp.angle(hilbert_jax(output, axis=2))
173+
return -jnp.mean(jnp.abs(jnp.mean(jnp.exp(complex(0,1)*phase), axis=1)))
174+
129175

130176
def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
131177
"""Compute the optimal control signal for noise averaging method 0.

0 commit comments

Comments
 (0)