|
13 | 13 | wc_default_control_params = ["exc_ext", "inh_ext"] |
14 | 14 | wc_default_target_params = ["exc", "inh"] |
15 | 15 |
|
| 16 | +def hilbert_jax(signal, axis=-1): |
| 17 | + |
| 18 | + n = signal.shape[axis] |
| 19 | + h = jnp.zeros(n) |
| 20 | + h = h.at[0].set(1) |
| 21 | + |
| 22 | + if n % 2 == 0: |
| 23 | + h = h.at[1:n//2].set(2) |
| 24 | + h = h.at[n//2].set(1) |
| 25 | + else: |
| 26 | + h = h.at[1:(n+1)//2].set(2) |
| 27 | + |
| 28 | + h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis)) |
| 29 | + h = jnp.broadcast_to(h, signal.shape) |
| 30 | + |
| 31 | + fft_signal = jnp.fft.fft(signal, axis=axis) |
| 32 | + analytic_fft = fft_signal * h |
| 33 | + |
| 34 | + analytic_signal = jnp.fft.ifft(analytic_fft) |
| 35 | + return analytic_signal |
| 36 | + |
16 | 37 |
|
17 | 38 | class OcWc: |
18 | 39 | def __init__( |
@@ -96,6 +117,10 @@ def accuracy_cost(self, output): |
96 | 117 | accuracy_cost += self.weights["w_var"] * self.compute_var_cost(output) |
97 | 118 | if self.weights["w_f_osc"] != 0.0: |
98 | 119 | accuracy_cost += self.weights["w_f_osc"] * self.compute_osc_fourier_cost(output) |
| 120 | + if self.weights["w_f_sync"] != 0.0: |
| 121 | + accuracy_cost += self.weights["w_f_sync"] * self.compute_sync_fourier_cost(output) |
| 122 | + if self.weights["w_kuramoto"] != 0.0: |
| 123 | + accuracy_cost += self.weights["w_kuramoto"] * self.compute_kuramoto_cost(output) |
99 | 124 | return accuracy_cost |
100 | 125 |
|
101 | 126 | def control_strength_cost(self, control): |
@@ -124,8 +149,29 @@ def compute_cc_cost(self, output): |
124 | 149 |
|
125 | 150 | def compute_var_cost(self, output): |
126 | 151 | return jnp.var(output, axis=(0, 1)).mean() |
| 152 | + |
| 153 | + def get_fourier_component(self, data, target_period): |
| 154 | + fourier_series = jnp.abs(jnp.fft.fft(data)[:len(data)//2]) |
| 155 | + freqs = jnp.fft.fftfreq(data.size,d=self.model.params.dt)[:len(data)//2] |
| 156 | + return fourier_series[jnp.argmin(jnp.abs(freqs - 1./target_period))] |
127 | 157 |
|
128 | 158 | def compute_osc_fourier_cost(self, output): |
| 159 | + cost = 0.0 |
| 160 | + for n in range(output.shape[1]): |
| 161 | + for v in range(output.shape[0]): |
| 162 | + cost -= self.get_fourier_component(output[v, n], self.target)**2 |
| 163 | + return cost / (output.shape[2] * self.model.params.dt)**2 |
| 164 | + |
| 165 | + def compute_sync_fourier_cost(self, output): |
| 166 | + cost = 0.0 |
| 167 | + for v in range(output.shape[0]): |
| 168 | + cost -= self.get_fourier_component(jnp.sum(output[v], axis=0), self.target)**2 |
| 169 | + return cost / (output.shape[2] * self.model.params.dt)**2 |
| 170 | + |
| 171 | + def compute_kuramoto_cost(self, output): |
| 172 | + phase = jnp.angle(hilbert_jax(output, axis=2)) |
| 173 | + return -jnp.mean(jnp.abs(jnp.mean(jnp.exp(complex(0,1)*phase), axis=1))) |
| 174 | + |
129 | 175 |
|
130 | 176 | def optimize_deterministic(self, n_max_iterations, output_every_nth=None): |
131 | 177 | """Compute the optimal control signal for noise averaging method 0. |
|
0 commit comments