@@ -102,6 +102,8 @@ def timeIntegration_args(params):
102102
103103 # ------------------------------------------------------------------------
104104
105+ integration_method = params ['integration_method' ]
106+
105107 return (
106108 startind ,
107109 t ,
@@ -134,6 +136,7 @@ def timeIntegration_args(params):
134136 tau_ou ,
135137 sigma_ou ,
136138 key ,
139+ integration_method
137140 )
138141
139142
@@ -170,6 +173,7 @@ def timeIntegration_elementwise(
170173 tau_ou ,
171174 sigma_ou ,
172175 key ,
176+ integration_method
173177):
174178
175179 update_step = get_update_step (
@@ -204,6 +208,7 @@ def timeIntegration_elementwise(
204208 tau_ou ,
205209 sigma_ou ,
206210 key ,
211+ integration_method
207212 )
208213
209214 # Iterating through time steps
@@ -222,7 +227,6 @@ def timeIntegration_elementwise(
222227 inh_ou ,
223228 )
224229
225-
226230def get_update_step (
227231 startind ,
228232 t ,
@@ -255,6 +259,7 @@ def get_update_step(
255259 tau_ou ,
256260 sigma_ou ,
257261 key ,
262+ integration_method
258263):
259264 key , subkey_exc = random .split (key )
260265 noise_exc = random .normal (subkey_exc , (N , len (t )))
@@ -269,7 +274,7 @@ def S_E(x):
269274 def S_I (x ):
270275 return 1.0 / (1.0 + jnp .exp (- a_inh * (x - mu_inh )))
271276
272- def update_step (state , _ ):
277+ def step_rhs (state ):
273278 exc_history , inh_history , exc_ou , inh_ou , i = state
274279
275280 # Vectorized calculation of delayed excitatory input
@@ -307,18 +312,64 @@ def update_step(state, _):
307312 + inh_ou # ou noise
308313 )
309314 )
315+
316+ exc_ou_rhs = (exc_ou_mean - exc_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc [:, i - startind ]
317+ inh_ou_rhs = (inh_ou_mean - inh_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh [:, i - startind ]
318+
319+ return exc_rhs , inh_rhs , exc_ou_rhs , inh_ou_rhs
320+
321+ def euler (state ):
322+ exc_rhs , inh_rhs , exc_ou_rhs , inh_ou_rhs = step_rhs (state )
323+ exc_history , inh_history , exc_ou , inh_ou , i = state
310324 # Euler integration
311325 # make sure e and i variables do not exceed 1 (can only happen with noise)
312326 exc_new = jnp .clip (exc_history [:, - 1 ] + dt * exc_rhs , 0 , 1 )
313327 inh_new = jnp .clip (inh_history [:, - 1 ] + dt * inh_rhs , 0 , 1 )
314328
315329 # Update Ornstein-Uhlenbeck process for noise
316330 exc_ou = (
317- exc_ou + ( exc_ou_mean - exc_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc [:, i - startind ]
331+ exc_ou + exc_ou_rhs
318332 ) # mV/ms
319333 inh_ou = (
320- inh_ou + (inh_ou_mean - inh_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh [:, i - startind ]
334+ inh_ou + inh_ou_rhs
335+ ) # mV/ms
336+
337+ return exc_new , inh_new , exc_ou , inh_ou
338+
339+ def heun (state ):
340+ # TODO
341+ exc_k1 , inh_k1 , exc_ou_rhs , inh_ou_rhs = step_rhs (state )
342+
343+ # Update Ornstein-Uhlenbeck process for noise
344+ exc_ou = (
345+ exc_ou + exc_ou_rhs
321346 ) # mV/ms
347+ inh_ou = (
348+ inh_ou + inh_ou_rhs
349+ ) # mV/ms
350+
351+ # make sure e and i variables do not exceed 1 (can only happen with noise)
352+ exc_new = jnp .clip (exc_history [:, - 1 ] + dt * exc_rhs , 0 , 1 )
353+ inh_new = jnp .clip (inh_history [:, - 1 ] + dt * inh_rhs , 0 , 1 )
354+
355+ exc_k1_history = jnp .concatenate ((exc_history [:, 1 :], jnp .expand_dims (exc_new , axis = 1 )), axis = 1 )
356+ inh_k1_history = jnp .concatenate ((inh_history [:, 1 :], jnp .expand_dims (inh_new , axis = 1 )), axis = 1 )
357+
358+ new_state = exc_k1_history , inh_k1_history , exc_ou , inh_ou
359+ exc_k2 , inh_k2 , _ , _ = step_rhs (new_state )
360+ exc_new = ...
361+ inh_new = ...
362+ return exc_new , inh_new , exc_ou , inh_ou
363+
364+ def update_step (state , _ ):
365+ exc_history , inh_history , exc_ou , inh_ou , i = state
366+ if integration_method == 'euler' :
367+ integration_f = euler
368+ else if integration_method == 'heun' :
369+ integration_f = heun
370+ else :
371+ raise Exception (f'Integration method { integration_method } not implemented.' )
372+ exc_new , inh_new , exc_ou , inh_ou = integration_f (state )
322373
323374 return (
324375 (
0 commit comments