@@ -74,7 +74,7 @@ def timeIntegration_args(params):
7474
7575 # ------------------------------------------------------------------------
7676 # Initialization
77- t = jnp .arange (1 , jnp .round (duration , 6 ) / dt + 1 ) * dt # Time variable (ms)
77+ t = jnp .arange (1 , 1 + jnp .round (duration , 6 ) / dt ) * dt # Time variable (ms)
7878 sqrt_dt = jnp .sqrt (dt )
7979
8080 max_global_delay = int (jnp .max (Dmat_ndt ))
@@ -87,8 +87,8 @@ def timeIntegration_args(params):
8787 exc_ext_baseline = params ["exc_ext_baseline" ]
8888 inh_ext_baseline = params ["inh_ext_baseline" ]
8989
90- exc_ext = mu .adjustArrayShape_jax (params ["exc_ext" ], jnp .zeros ((N , startind + len (t ))))
91- inh_ext = mu .adjustArrayShape_jax (params ["inh_ext" ], jnp .zeros ((N , startind + len (t ))))
90+ exc_ext = mu .adjustArrayShape_jax (params ["exc_ext" ], jnp .zeros ((N , len (t ))))
91+ inh_ext = mu .adjustArrayShape_jax (params ["inh_ext" ], jnp .zeros ((N , len (t ))))
9292
9393 # Set initial values
9494 # if initial values are just a Nx1 array
@@ -209,7 +209,7 @@ def timeIntegration_elementwise(
209209 # Iterating through time steps
210210 (exc_history , inh_history , exc_ou , inh_ou , i ), (excs_new , inhs_new ) = jax .lax .scan (
211211 update_step ,
212- (exc_init , inh_init , exc_ou_init , inh_ou_init , startind ),
212+ (exc_init , inh_init , exc_ou_init , inh_ou_init , 0 ),
213213 xs = None ,
214214 length = len (t ),
215215 )
@@ -287,7 +287,7 @@ def update_step(state, _):
287287 - c_inhexc * inh_history [:, - 1 ] # input from the inhibitory population
288288 + exc_input_d # input from other nodes
289289 + exc_ext_baseline # baseline external input (static)
290- + exc_ext [:, i - 1 ] # time-dependent external input
290+ + exc_ext [:, i ] # time-dependent external input
291291 )
292292 + exc_ou # ou noise
293293 )
@@ -302,7 +302,7 @@ def update_step(state, _):
302302 c_excinh * exc_history [:, - 1 ] # input from the excitatory population
303303 - c_inhinh * inh_history [:, - 1 ] # input from within the inhibitory population
304304 + inh_ext_baseline # baseline external input (static)
305- + inh_ext [:, i - 1 ] # time-dependent external input
305+ + inh_ext [:, i ] # time-dependent external input
306306 )
307307 + inh_ou # ou noise
308308 )
@@ -313,12 +313,8 @@ def update_step(state, _):
313313 inh_new = jnp .clip (inh_history [:, - 1 ] + dt * inh_rhs , 0 , 1 )
314314
315315 # Update Ornstein-Uhlenbeck process for noise
316- exc_ou = (
317- exc_ou + (exc_ou_mean - exc_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc [:, i - startind ]
318- ) # mV/ms
319- inh_ou = (
320- inh_ou + (inh_ou_mean - inh_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh [:, i - startind ]
321- ) # mV/ms
316+ exc_ou = exc_ou + (exc_ou_mean - exc_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc [:, i ] # mV/ms
317+ inh_ou = inh_ou + (inh_ou_mean - inh_ou ) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh [:, i ] # mV/ms
322318
323319 return (
324320 (
0 commit comments