
Commit ec2434e

Directional sparsity cost functional (#259)
* implement directional sparsity
* Update cost_functions.py
* Update oc.py
* Update oc.py
1 parent b797770 commit ec2434e
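For orientation (my reading of the committed code, not part of the commit message): the new 'L1D_cost_integral' implements a group-sparsity penalty, an L2 norm over time summed over nodes n and input channels v, so each control channel is penalized as a whole and is driven to exactly zero unless it pays for itself in the accuracy term. In the notation u_{n,v}[t] (mine), the cost and its gradient read

F_{1D}(u) = \sum_{n}\sum_{v} \sqrt{\sum_{t} u_{n,v}[t]^{2}\,\Delta t},
\qquad
\frac{\partial F_{1D}}{\partial u_{n,v}[t]} = \frac{u_{n,v}[t]\,\Delta t}{\sqrt{\sum_{\tau} u_{n,v}[\tau]^{2}\,\Delta t}}

Note that 'derivative_L1D_cost' returns u_{n,v}[t] divided by the square-root term only, i.e. the gradient without the Δt factor; this appears to follow the same convention as 'derivative_L2_cost', which likewise returns the integrand of its cost term rather than the derivative of the time-integrated cost.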

File tree

3 files changed: +84 -11 lines

neurolib/control/optimal_control/cost_functions.py

Lines changed: 56 additions & 8 deletions
@@ -19,7 +19,7 @@ def accuracy_cost(x, target_timeseries, weights, cost_matrix, dt, interval=(0, N
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple, optional
-
+
     :return: Accuracy cost.
     :rtype: float
     """
@@ -56,7 +56,7 @@ def derivative_accuracy_cost(x, target_timeseries, weights, cost_matrix, interva
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple, optional
-
+
     :return: Accuracy cost derivative.
     :rtype: ndarray
     """
@@ -84,7 +84,7 @@ def precision_cost(x_sim, x_target, cost_matrix, interval=(0, None)):
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple
-
+
     :return: Precision cost for time interval.
     :rtype: float
     """
@@ -114,7 +114,7 @@ def derivative_precision_cost(x_sim, x_target, cost_matrix, interval):
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple
-
+
     :return: Control-dimensions x T array of precision cost gradients.
     :rtype: np.ndarray
     """
@@ -140,7 +140,7 @@ def control_strength_cost(u, weights, dt):
     :type weights: dictionary
     :param dt: Time step.
     :type dt: float
-
+
     :return: control strength cost of the control.
     :rtype: float
     """
@@ -159,17 +159,22 @@ def control_strength_cost(u, weights, dt):
                 for t in range(u.shape[2]):
                     cost += cost_timeseries[n, v, t] * dt

+    if weights["w_1D"] != 0.0:
+        cost += weights["w_1D"] * L1D_cost_integral(u, dt)
+
     return cost


 @numba.njit
-def derivative_control_strength_cost(u, weights):
+def derivative_control_strength_cost(u, weights, dt):
     """Derivative of the 'control_strength_cost' wrt. the control 'u'.

     :param u: Control-dimensions x T array. Control signals.
     :type u: np.ndarray
     :param weights: Dictionary of weights.
     :type weights: dictionary
+    :param dt: Time step.
+    :type dt: float

     :return: Control-dimensions x T array of L2-cost gradients.
     :rtype: np.ndarray
@@ -179,6 +184,8 @@ def derivative_control_strength_cost(u, weights):

     if weights["w_2"] != 0.0:
         der += weights["w_2"] * derivative_L2_cost(u)
+    if weights["w_1D"] != 0.0:
+        der += weights["w_1D"] * derivative_L1D_cost(u, dt)

     return der

@@ -189,7 +196,7 @@ def L2_cost(u):

     :param u: Control-dimensions x T array. Control signals.
     :type u: np.ndarray
-
+
     :return: L2 cost of the control.
     :rtype: float
     """
@@ -203,8 +210,49 @@ def derivative_L2_cost(u):

     :param u: Control-dimensions x T array. Control signals.
     :type u: np.ndarray
-
+
     :return: Control-dimensions x T array of L2-cost gradients.
     :rtype: np.ndarray
     """
     return u
+
+
+@numba.njit
+def L1D_cost_integral(
+    u,
+    dt,
+):
+    """'Directional sparsity' or 'L1D' cost integrated over time. Penalizes control strength.
+    :param u: Control-dimensions x T array. Control signals.
+    :type u: np.ndarray
+    :param dt: Time step.
+    :type dt: float
+    :return: L1D cost of the control.
+    :rtype: float
+    """
+
+    return np.sum(np.sum(np.sqrt(np.sum(u**2, axis=2) * dt), axis=1), axis=0)
+
+
+@numba.njit
+def derivative_L1D_cost(
+    u,
+    dt,
+):
+    """Derivative of the 'L1D_cost_integral' wrt. the control 'u'.
+    :param u: Control-dimensions x T array. Control signals.
+    :type u: np.ndarray
+    :param dt: Time step.
+    :type dt: float
+    :return: Control-dimensions x T array of L1D-cost gradients.
+    :rtype: np.ndarray
+    """
+
+    denominator = np.sqrt(np.sum(u**2, axis=2) * dt)
+    der = np.zeros((u.shape))
+    for n in range(der.shape[0]):
+        for v in range(der.shape[1]):
+            if denominator[n, v] != 0.0:
+                der[n, v, :] = u[n, v, :] / denominator[n, v]
+
+    return der
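A quick way to sanity-check the gradient convention (a standalone NumPy sketch, not part of the commit; the helper names 'l1d_cost' and 'l1d_derivative' and the toy array shapes are mine) is to compare the analytic routine against a central finite difference of the integrated cost. Note the extra dt factor in the comparison, since 'derivative_L1D_cost' returns the integrand of the gradient rather than the derivative of the integrated cost itself:

import numpy as np

def l1d_cost(u, dt):
    # Mirror of L1D_cost_integral: L2 norm over the time axis,
    # summed over nodes (axis 0) and control channels (axis 1).
    return np.sum(np.sqrt(np.sum(u**2, axis=2) * dt))

def l1d_derivative(u, dt):
    # Mirror of derivative_L1D_cost: integrand of the cost gradient.
    denominator = np.sqrt(np.sum(u**2, axis=2) * dt)
    der = np.zeros(u.shape)
    for n in range(u.shape[0]):
        for v in range(u.shape[1]):
            if denominator[n, v] != 0.0:
                der[n, v, :] = u[n, v, :] / denominator[n, v]
    return der

rng = np.random.default_rng(0)
u = rng.normal(size=(2, 2, 50))  # toy N x control-dim x T array
dt, eps = 0.1, 1e-6
n, v, t = 1, 0, 25               # probe a single entry

u_plus, u_minus = u.copy(), u.copy()
u_plus[n, v, t] += eps
u_minus[n, v, t] -= eps
fd = (l1d_cost(u_plus, dt) - l1d_cost(u_minus, dt)) / (2 * eps)

# The analytic routine returns the integrand, so multiply by dt to compare.
print(np.isclose(fd, l1d_derivative(u, dt)[n, v, t] * dt))  # expect True

The zero-denominator guard in 'derivative_L1D_cost' covers channels with identically zero control, where the square-root cost is not differentiable.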

neurolib/control/optimal_control/oc.py

Lines changed: 4 additions & 3 deletions
@@ -17,6 +17,7 @@ def getdefaultweights():
     )
     weights["w_p"] = 1.0
     weights["w_2"] = 0.0
+    weights["w_1D"] = 0.0

     return weights

@@ -471,14 +472,14 @@ def __init__(
         for v, iv in enumerate(self.model.input_vars):
             control[:, v, :] = self.model.params[iv]

-        self.control = control.copy()
+        self.control = control.copy()
         self.check_params()

         self.control = update_control_with_limit(
             self.N, self.dim_in, self.T, control, 0.0, np.zeros(control.shape), self.maximum_control_strength
         )

-        self.model_params = self.get_model_params()
+        self.model_params = self.get_model_params()

     def check_params(self):
         """Checks a subset of parameters and throws an error if a wrong dimension is found."""
@@ -624,7 +625,7 @@ def compute_gradient(self):
         :rtype: np.ndarray of shape N x V x T
         """
         self.solve_adjoint()
-        df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights)
+        df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights, self.dt)
         d_du = self.Duh()

         return compute_gradient(
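With 'getdefaultweights' extended and 'compute_gradient' now forwarding 'self.dt', the new term is switched on purely through the weights dictionary. A minimal sketch of calling the changed functions directly (import paths follow the file locations in this commit; the array shapes and sine input are illustrative, not from the source):

import numpy as np

from neurolib.control.optimal_control import cost_functions
from neurolib.control.optimal_control.oc import getdefaultweights

dt = 0.1
u = np.zeros((1, 2, 100))  # N x input-dimensions x T
u[0, 0, :] = np.sin(np.linspace(0.0, 2.0 * np.pi, 100))  # only channel 0 is active

weights = getdefaultweights()
weights["w_1D"] = 1.0  # enable the directional-sparsity term (default 0.0)

cost = cost_functions.control_strength_cost(u, weights, dt)
grad = cost_functions.derivative_control_strength_cost(u, weights, dt)  # note the new dt argument

# The inactive channel contributes neither cost nor gradient,
# thanks to the zero-denominator guard in derivative_L1D_cost.
print(cost, np.abs(grad[0, 1, :]).max())

In a full optimization, any nonzero 'w_1D' then enters every gradient step via 'compute_gradient', trading tracking accuracy against channel-wise sparsity of the control.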

tests/control/optimal_control/test_oc_cost_functions.py

Lines changed: 24 additions & 0 deletions
@@ -160,6 +160,30 @@ def test_derivative_L2_cost(self):
         desired_output = u
         self.assertTrue(np.all(cost_functions.derivative_L2_cost(u) == desired_output))

+    def test_L1D_cost(self):
+        print(" Test L1D cost")
+        dt = 0.1
+        reference_result = 2.0 * np.sum(np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1)))
+        weights = getdefaultweights()
+        weights["w_1D"] = 1.0
+        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
+        L1D_cost = cost_functions.control_strength_cost(u, weights, dt)
+
+        self.assertAlmostEqual(L1D_cost, reference_result, places=8)
+
+    def test_derivative_L1D_cost(self):
+        print(" Test L1D cost derivative")
+        dt = 0.1
+        denominator = np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1))
+
+        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
+        reference_result = np.zeros((u.shape))
+        for n in range(u.shape[0]):
+            for v in range(u.shape[1]):
+                reference_result[n, v, :] = u[n, v, :] / denominator[n]
+
+        self.assertTrue(np.all(cost_functions.derivative_L1D_cost(u, dt) == reference_result))
+
     def test_weights_dictionary(self):
         print("Test dictionary of cost weights")
         model = FHNModel()