Skip to content

Commit 5ae2b5c

Browse files
committed
cleanup
1 parent 6c4b3b2 commit 5ae2b5c

File tree

2 files changed

+173
-108
lines changed

2 files changed

+173
-108
lines changed

neurolib/control/optimal_control/oc_jax.py

Lines changed: 61 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,15 @@
55
import copy
66
from neurolib.models.jax.wc import WCModel
77
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
8-
9-
import logging
8+
from neurolib.optimize.loss_functions import (
9+
kuramoto_loss,
10+
cross_correlation_loss,
11+
variance_loss,
12+
osc_fourier_loss,
13+
sync_fourier_loss,
14+
)
1015
from neurolib.control.optimal_control.oc import getdefaultweights
1116

12-
# TODO: introduce for all models, not just WC
13-
wc_default_control_params = ["exc_ext", "inh_ext"]
14-
wc_default_target_params = ["exc", "inh"]
15-
16-
17-
def hilbert_jax(signal, axis=-1):
18-
19-
n = signal.shape[axis]
20-
h = jnp.zeros(n)
21-
h = h.at[0].set(1)
22-
23-
if n % 2 == 0:
24-
h = h.at[1 : n // 2].set(2)
25-
h = h.at[n // 2].set(1)
26-
else:
27-
h = h.at[1 : (n + 1) // 2].set(2)
28-
29-
h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis))
30-
h = jnp.broadcast_to(h, signal.shape)
31-
32-
fft_signal = jnp.fft.fft(signal, axis=axis)
33-
analytic_fft = fft_signal * h
34-
35-
analytic_signal = jnp.fft.ifft(analytic_fft)
36-
return analytic_signal
37-
3817

3918
class Optimize:
4019
def __init__(
@@ -43,8 +22,8 @@ def __init__(
4322
loss_function,
4423
param_names,
4524
target_param_names,
46-
target=None,
4725
init_params=None,
26+
regularization_function=lambda _: 0.0,
4827
optimizer=optax.adabelief(1e-3),
4928
):
5029
assert isinstance(param_names, (list, tuple)) and len(param_names) > 0
@@ -54,7 +33,7 @@ def __init__(
5433

5534
self.model = copy.deepcopy(model)
5635
self.loss_function = loss_function
57-
self.target = target
36+
self.regularization_function = regularization_function
5837
self.optimizer = optimizer
5938
self.param_names = param_names
6039
self.target_param_names = target_param_names
@@ -70,7 +49,7 @@ def __init__(
7049
self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
7150
self.opt_state = self.optimizer.init(self.params)
7251

73-
compute_loss = lambda params: self.loss_function(params, self.get_output(params))
52+
compute_loss = lambda params: self.loss_function(self.get_output(params)) + self.regularization_function(params)
7453
self.compute_loss = jax.jit(compute_loss)
7554
self.compute_gradient = jax.jit(jax.grad(self.compute_loss))
7655

@@ -93,15 +72,7 @@ def get_output(self, params):
9372
simulation_results = self.simulate(params)
9473
return jnp.stack([simulation_results[tp][:, self.startind :] for tp in self.target_param_names])
9574

96-
def get_loss(self):
97-
@jax.jit
98-
def loss(params):
99-
output = self.get_output(params)
100-
return self.loss_function(params, output)
101-
102-
return loss
103-
104-
def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
75+
def optimize(self, n_max_iterations, output_every_nth=None):
10576
loss = self.compute_loss(self.control)
10677
print(f"loss in iteration 0: %s" % (loss))
10778
if len(self.cost_history) == 0: # add only if control model has not yet been optimized
@@ -121,105 +92,87 @@ def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
12192
print(f"Final loss : %s" % (loss))
12293

12394

124-
class OcWc(Optimize):
95+
class Oc(Optimize):
96+
"""
97+
Convenience class for optimal control. The cost functional is constructed as a weighted sum of accuracy and control strength costs. Requires optimization parameters to be of shape (N, T).
98+
"""
99+
100+
supported_cost_parameters = [
101+
"w_p",
102+
"w_cc",
103+
"w_var",
104+
"w_f_osc",
105+
"w_f_sync",
106+
"w_ko",
107+
"w_2",
108+
"w_1D",
109+
]
110+
125111
def __init__(
126112
self,
127113
model,
128-
target=None,
114+
target_timeseries=None,
115+
target_frequency=None,
129116
optimizer=optax.adabelief(1e-3),
130-
control_param_names=wc_default_control_params,
131-
target_param_names=wc_default_target_params,
117+
control_param_names=["exc_ext", "inh_ext"],
118+
target_param_names=["exc", "inh"],
119+
weights=None,
132120
):
133121
super().__init__(
134122
model,
135-
self.compute_total_cost,
123+
self.accuracy_cost,
136124
control_param_names,
137125
target_param_names,
138-
target=target,
139126
init_params=None,
140127
optimizer=optimizer,
128+
regularization_function=self.control_strength_cost,
141129
)
130+
self.target_timeseries = target_timeseries
131+
self.target_frequency = target_frequency
142132
self.control = self.params
143-
self.weights = getdefaultweights()
133+
if weights is None:
134+
self.weights = getdefaultweights()
144135

145-
def compute_total_cost(self, control, output):
136+
def accuracy_cost(self, output):
146137
"""
147-
Compute the total cost as the sum of accuracy cost and control strength cost.
148-
149-
Parameters:
150-
control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
151-
output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
152-
153-
Returns:
154-
float: The total cost.
138+
Args:
139+
output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
155140
"""
156-
accuracy_cost = self.accuracy_cost(output)
157-
control_arr = jnp.array(list(control.values()))
158-
control_strength_cost = self.control_strength_cost(control_arr)
159-
return accuracy_cost + control_strength_cost
160-
161-
# TODO: move cost functions outside
162-
def accuracy_cost(self, output):
163141
accuracy_cost = 0.0
164142
if self.weights["w_p"] != 0.0:
165-
accuracy_cost += self.weights["w_p"] * 0.5 * self.model.params.dt * jnp.sum((output - self.target) ** 2)
143+
accuracy_cost += self.weights["w_p"] * self.precision_cost(output)
166144
if self.weights["w_cc"] != 0.0:
167-
accuracy_cost += self.weights["w_cc"] * self.compute_cc_cost(output)
145+
accuracy_cost += self.weights["w_cc"] * cross_correlation_loss(output, self.model.params.dt)
168146
if self.weights["w_var"] != 0.0:
169-
accuracy_cost += self.weights["w_var"] * self.compute_var_cost(output)
147+
accuracy_cost += self.weights["w_var"] * variance_loss(output)
170148
if self.weights["w_f_osc"] != 0.0:
171-
accuracy_cost += self.weights["w_f_osc"] * self.compute_osc_fourier_cost(output)
149+
accuracy_cost += self.weights["w_f_osc"] * osc_fourier_loss(
150+
output, self.target_frequency, self.model.params.dt
151+
)
172152
if self.weights["w_f_sync"] != 0.0:
173-
accuracy_cost += self.weights["w_f_sync"] * self.compute_sync_fourier_cost(output)
153+
accuracy_cost += self.weights["w_f_sync"] * sync_fourier_loss(
154+
output, self.target_frequency, self.model.params.dt
155+
)
174156
if self.weights["w_ko"] != 0.0:
175-
accuracy_cost += self.weights["w_ko"] * self.compute_kuramoto_cost(output)
157+
accuracy_cost += self.weights["w_ko"] * kuramoto_loss(output)
176158
return accuracy_cost
177159

160+
def precision_cost(self, output):
161+
return 0.5 * self.model.params.dt * jnp.sum((output - self.target_timeseries) ** 2)
162+
178163
def control_strength_cost(self, control):
164+
"""
165+
Args:
166+
control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
167+
"""
168+
control_arr = jnp.array(list(control.values()))
179169
control_strength_cost = 0.0
180170
if self.weights["w_2"] != 0.0:
181-
control_strength_cost += self.weights["w_2"] * 0.5 * self.model.params.dt * jnp.sum(control**2)
171+
control_strength_cost += self.weights["w_2"] * 0.5 * self.model.params.dt * jnp.sum(control_arr**2)
182172
if self.weights["w_1D"] != 0.0:
183-
control_strength_cost += self.weights["w_1D"] * self.compute_ds_cost(control)
173+
control_strength_cost += self.weights["w_1D"] * self.compute_ds_cost(control_arr)
184174
return control_strength_cost
185175

186176
def compute_ds_cost(self, control):
187177
eps = 1e-6 # avoid grad(sqrt(0.0))
188178
return jnp.sum(jnp.sqrt(jnp.sum(control**2, axis=2) * self.model.params.dt + eps))
189-
190-
def compute_cc_cost(self, output):
191-
xmean = jnp.mean(output, axis=2, keepdims=True)
192-
xstd = jnp.std(output, axis=2, keepdims=True)
193-
194-
xvec = (output - xmean) / xstd
195-
196-
costmat = jnp.einsum("vnt,vkt->vnkt", xvec, xvec)
197-
diag = jnp.einsum("vnt,vnt->vt", xvec, xvec)
198-
cost = jnp.sum(jnp.sum(costmat, axis=(1, 2)) - diag) * self.model.params.dt / 2.0
199-
cost *= -2.0 / (self.model.params.N * (self.model.params.N - 1) * self.T * self.model.params.dt)
200-
return cost
201-
202-
def compute_var_cost(self, output):
203-
return jnp.var(output, axis=(0, 1)).mean()
204-
205-
def get_fourier_component(self, data, target_period):
206-
fourier_series = jnp.abs(jnp.fft.fft(data)[: len(data) // 2])
207-
freqs = jnp.fft.fftfreq(data.size, d=self.model.params.dt)[: len(data) // 2]
208-
return fourier_series[jnp.argmin(jnp.abs(freqs - 1.0 / target_period))]
209-
210-
def compute_osc_fourier_cost(self, output):
211-
cost = 0.0
212-
for n in range(output.shape[1]):
213-
for v in range(output.shape[0]):
214-
cost -= self.get_fourier_component(output[v, n], self.target) ** 2
215-
return cost / (output.shape[2] * self.model.params.dt) ** 2
216-
217-
def compute_sync_fourier_cost(self, output):
218-
cost = 0.0
219-
for v in range(output.shape[0]):
220-
cost -= self.get_fourier_component(jnp.sum(output[v], axis=0), self.target) ** 2
221-
return cost / (output.shape[2] * self.model.params.dt) ** 2
222-
223-
def compute_kuramoto_cost(self, output):
224-
phase = jnp.angle(hilbert_jax(output, axis=2))
225-
return -jnp.mean(jnp.abs(jnp.mean(jnp.exp(complex(0, 1) * phase), axis=1)))
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
5+
def variance_loss(output):
6+
"""
7+
Args:
8+
output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
9+
where N is number of nodes and T is number of timepoints
10+
11+
Returns:
12+
float: Variance over time, averaged across output variables and nodes
13+
"""
14+
return jnp.var(output, axis=(0, 1)).mean()
15+
16+
17+
def cross_correlation_loss(output, dt=1.0):
18+
"""
19+
Args:
20+
output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
21+
where N is number of nodes and T is number of timepoints
22+
dt (float): Time step
23+
24+
Returns:
25+
float: Negative cross-correlation
26+
"""
27+
_, N, T = output.shape
28+
xmean = jnp.mean(output, axis=2, keepdims=True)
29+
xstd = jnp.std(output, axis=2, keepdims=True)
30+
31+
xvec = (output - xmean) / xstd
32+
33+
lossmat = jnp.einsum("vnt,vkt->vnkt", xvec, xvec)
34+
diag = jnp.einsum("vnt,vnt->vt", xvec, xvec)
35+
loss = jnp.sum(jnp.sum(lossmat, axis=(1, 2)) - diag) * dt / 2.0
36+
loss *= -2.0 / (N * (N - 1) * T * dt)
37+
return loss
38+
39+
40+
def hilbert(signal, axis=-1):
41+
n = signal.shape[axis]
42+
h = jnp.zeros(n)
43+
h = h.at[0].set(1)
44+
45+
if n % 2 == 0:
46+
h = h.at[1 : n // 2].set(2)
47+
h = h.at[n // 2].set(1)
48+
else:
49+
h = h.at[1 : (n + 1) // 2].set(2)
50+
51+
h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis))
52+
h = jnp.broadcast_to(h, signal.shape)
53+
54+
fft_signal = jnp.fft.fft(signal, axis=axis)
55+
analytic_fft = fft_signal * h
56+
57+
analytic_signal = jnp.fft.ifft(analytic_fft)
58+
return analytic_signal
59+
60+
61+
def kuramoto_loss(output):
62+
"""
63+
Args:
64+
output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
65+
where N is number of nodes and T is number of timepoints
66+
67+
Returns:
68+
float: Negative Kuramoto order parameter averaged over output variables
69+
"""
70+
phase = jnp.angle(hilbert(output, axis=2))
71+
return -jnp.mean(jnp.real(jnp.mean(jnp.exp(1j * phase), axis=1)))
72+
73+
74+
def get_fourier_component(data, target_frequency, dt=1.0):
75+
fourier_series = jnp.abs(jnp.fft.fft(data)[: len(data) // 2])
76+
freqs = jnp.fft.fftfreq(data.size, d=dt)[: len(data) // 2]
77+
return fourier_series[jnp.argmin(jnp.abs(freqs - target_frequency))]
78+
79+
80+
def osc_fourier_loss(output, target_frequency, dt=1.0):
81+
"""
82+
Args:
83+
output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
84+
where N is number of nodes and T is number of timepoints
85+
target_frequency (float): Frequency to optimize for
86+
dt (float): Time step
87+
88+
Returns:
89+
float: Negative synchronization of output nodes at target frequency, irrespective of phase
90+
"""
91+
loss = 0.0
92+
for n in range(output.shape[1]):
93+
for v in range(output.shape[0]):
94+
loss -= get_fourier_component(output[v, n], target_frequency) ** 2
95+
return loss / (output.shape[2] * dt) ** 2
96+
97+
98+
def sync_fourier_loss(output, target_frequency, dt=1.0):
99+
"""
100+
Args:
101+
output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
102+
where N is number of nodes and T is number of timepoints
103+
target_frequency (float): Frequency to optimize for
104+
dt (float): Time step
105+
106+
Returns:
107+
float: Negative synchronization of output nodes at target frequency, considering phase
108+
"""
109+
loss = 0.0
110+
for v in range(output.shape[0]):
111+
loss -= get_fourier_component(jnp.sum(output[v], axis=0), target_frequency) ** 2
112+
return loss / (output.shape[2] * dt) ** 2

0 commit comments

Comments
 (0)