Skip to content

Commit 8628d03

Browse files
committed
Merge remote-tracking branch 'origin/jax' into jax
2 parents a6f56ad + 288ca67 commit 8628d03

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

neurolib/models/jax/wc/timeIntegration.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def timeIntegration_args(params):
102102

103103
# ------------------------------------------------------------------------
104104

105+
integration_method = params['integration_method']
106+
105107
return (
106108
startind,
107109
t,
@@ -134,6 +136,7 @@ def timeIntegration_args(params):
134136
tau_ou,
135137
sigma_ou,
136138
key,
139+
integration_method
137140
)
138141

139142

@@ -170,6 +173,7 @@ def timeIntegration_elementwise(
170173
tau_ou,
171174
sigma_ou,
172175
key,
176+
integration_method
173177
):
174178

175179
update_step = get_update_step(
@@ -204,6 +208,7 @@ def timeIntegration_elementwise(
204208
tau_ou,
205209
sigma_ou,
206210
key,
211+
integration_method
207212
)
208213

209214
# Iterating through time steps
@@ -222,7 +227,6 @@ def timeIntegration_elementwise(
222227
inh_ou,
223228
)
224229

225-
226230
def get_update_step(
227231
startind,
228232
t,
@@ -255,6 +259,7 @@ def get_update_step(
255259
tau_ou,
256260
sigma_ou,
257261
key,
262+
integration_method
258263
):
259264
key, subkey_exc = random.split(key)
260265
noise_exc = random.normal(subkey_exc, (N, len(t)))
@@ -269,7 +274,7 @@ def S_E(x):
269274
def S_I(x):
270275
return 1.0 / (1.0 + jnp.exp(-a_inh * (x - mu_inh)))
271276

272-
def update_step(state, _):
277+
def step_rhs(state):
273278
exc_history, inh_history, exc_ou, inh_ou, i = state
274279

275280
# Vectorized calculation of delayed excitatory input
@@ -307,18 +312,64 @@ def update_step(state, _):
307312
+ inh_ou # ou noise
308313
)
309314
)
315+
316+
exc_ou_rhs = (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind]
317+
inh_ou_rhs = (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind]
318+
319+
return exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs
320+
321+
def euler(state):
322+
exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs = step_rhs(state)
323+
exc_history, inh_history, exc_ou, inh_ou, i = state
310324
# Euler integration
311325
# make sure e and i variables do not exceed 1 (can only happen with noise)
312326
exc_new = jnp.clip(exc_history[:, -1] + dt * exc_rhs, 0, 1)
313327
inh_new = jnp.clip(inh_history[:, -1] + dt * inh_rhs, 0, 1)
314328

315329
# Update Ornstein-Uhlenbeck process for noise
316330
exc_ou = (
317-
exc_ou + (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind]
331+
exc_ou + exc_ou_rhs
318332
) # mV/ms
319333
inh_ou = (
320-
inh_ou + (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind]
334+
inh_ou + inh_ou_rhs
335+
) # mV/ms
336+
337+
return exc_new, inh_new, exc_ou, inh_ou
338+
339+
def heun(state):
340+
# TODO
341+
exc_k1, inh_k1, exc_ou_rhs, inh_ou_rhs = step_rhs(state)
342+
343+
# Update Ornstein-Uhlenbeck process for noise
344+
exc_ou = (
345+
exc_ou + exc_ou_rhs
321346
) # mV/ms
347+
inh_ou = (
348+
inh_ou + inh_ou_rhs
349+
) # mV/ms
350+
351+
# make sure e and i variables do not exceed 1 (can only happen with noise)
352+
exc_new = jnp.clip(exc_history[:, -1] + dt * exc_rhs, 0, 1)
353+
inh_new = jnp.clip(inh_history[:, -1] + dt * inh_rhs, 0, 1)
354+
355+
exc_k1_history = jnp.concatenate((exc_history[:, 1:], jnp.expand_dims(exc_new, axis=1)), axis=1)
356+
inh_k1_history = jnp.concatenate((inh_history[:, 1:], jnp.expand_dims(inh_new, axis=1)), axis=1)
357+
358+
new_state = exc_k1_history, inh_k1_history, exc_ou, inh_ou
359+
exc_k2, inh_k2, _, _ = step_rhs(new_state)
360+
exc_new = ...
361+
inh_new = ...
362+
return exc_new, inh_new, exc_ou, inh_ou
363+
364+
def update_step(state, _):
365+
exc_history, inh_history, exc_ou, inh_ou, i = state
366+
if integration_method == 'euler':
367+
integration_f = euler
368+
else if integration_method == 'heun':
369+
integration_f = heun
370+
else:
371+
raise Exception(f'Integration method {integration_method} not implemented.')
372+
exc_new, inh_new, exc_ou, inh_ou = integration_f(state)
322373

323374
return (
324375
(

neurolib/models/wc/loadDefaultParams.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,6 @@ def loadDefaultParams(Cmat=None, Dmat=None, seed=None):
8181
params.exc_ou = np.zeros((params.N,))
8282
params.inh_ou = np.zeros((params.N,))
8383

84+
params.integration_method = 'euler'
85+
8486
return params

tests/test_jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_single_node_deterministic(self):
2626
model_jax = WCModel_jax(seed=0)
2727
model_jax.params["duration"] = 1.0 * 1000
2828
model_jax.params["sigma_ou"] = 0.0
29+
model_jax.params['integration_method'] = 'euler'
2930

3031
model_jax.run()
3132

@@ -48,6 +49,7 @@ def test_single_node_dist(self):
4849
model_jax = WCModel_jax()
4950
model_jax.params["duration"] = 5.0 * 1000
5051
model_jax.params["sigma_ou"] = 0.01
52+
model_jax.params['integration_method'] = 'euler'
5153

5254
model_jax.run()
5355

@@ -86,6 +88,7 @@ def test_network(self):
8688
model.params["duration"] = 10 * 1000
8789
model.params["sigma_ou"] = 0.0
8890
model.params["K_gl"] = 0.6
91+
model_jax.params['integration_method'] = 'euler'
8992

9093
# local node input parameter
9194
model.params["exc_ext"] = 0.72

0 commit comments

Comments
 (0)