Skip to content

Commit c02dfc7

Browse files
1b15lenasal
andcommitted
start timeIntegration i at 0 and not startind
Co-authored-by: Lena Salfenmoser <lenasal@users.noreply.github.com>
1 parent de0e813 commit c02dfc7

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

neurolib/models/jax/wc/timeIntegration.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def timeIntegration_args(params):
7474

7575
# ------------------------------------------------------------------------
7676
# Initialization
77-
t = jnp.arange(1, jnp.round(duration, 6) / dt + 1) * dt # Time variable (ms)
77+
t = jnp.arange(1, 1 + jnp.round(duration, 6) / dt) * dt # Time variable (ms)
7878
sqrt_dt = jnp.sqrt(dt)
7979

8080
max_global_delay = int(jnp.max(Dmat_ndt))
@@ -87,8 +87,8 @@ def timeIntegration_args(params):
8787
exc_ext_baseline = params["exc_ext_baseline"]
8888
inh_ext_baseline = params["inh_ext_baseline"]
8989

90-
exc_ext = mu.adjustArrayShape_jax(params["exc_ext"], jnp.zeros((N, startind + len(t))))
91-
inh_ext = mu.adjustArrayShape_jax(params["inh_ext"], jnp.zeros((N, startind + len(t))))
90+
exc_ext = mu.adjustArrayShape_jax(params["exc_ext"], jnp.zeros((N, len(t))))
91+
inh_ext = mu.adjustArrayShape_jax(params["inh_ext"], jnp.zeros((N, len(t))))
9292

9393
# Set initial values
9494
# if initial values are just a Nx1 array
@@ -209,7 +209,7 @@ def timeIntegration_elementwise(
209209
# Iterating through time steps
210210
(exc_history, inh_history, exc_ou, inh_ou, i), (excs_new, inhs_new) = jax.lax.scan(
211211
update_step,
212-
(exc_init, inh_init, exc_ou_init, inh_ou_init, startind),
212+
(exc_init, inh_init, exc_ou_init, inh_ou_init, 0),
213213
xs=None,
214214
length=len(t),
215215
)
@@ -287,7 +287,7 @@ def update_step(state, _):
287287
- c_inhexc * inh_history[:, -1] # input from the inhibitory population
288288
+ exc_input_d # input from other nodes
289289
+ exc_ext_baseline # baseline external input (static)
290-
+ exc_ext[:, i - 1] # time-dependent external input
290+
+ exc_ext[:, i] # time-dependent external input
291291
)
292292
+ exc_ou # ou noise
293293
)
@@ -302,7 +302,7 @@ def update_step(state, _):
302302
c_excinh * exc_history[:, -1] # input from the excitatory population
303303
- c_inhinh * inh_history[:, -1] # input from within the inhibitory population
304304
+ inh_ext_baseline # baseline external input (static)
305-
+ inh_ext[:, i - 1] # time-dependent external input
305+
+ inh_ext[:, i] # time-dependent external input
306306
)
307307
+ inh_ou # ou noise
308308
)
@@ -313,12 +313,8 @@ def update_step(state, _):
313313
inh_new = jnp.clip(inh_history[:, -1] + dt * inh_rhs, 0, 1)
314314

315315
# Update Ornstein-Uhlenbeck process for noise
316-
exc_ou = (
317-
exc_ou + (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind]
318-
) # mV/ms
319-
inh_ou = (
320-
inh_ou + (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind]
321-
) # mV/ms
316+
exc_ou = exc_ou + (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i] # mV/ms
317+
inh_ou = inh_ou + (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i] # mV/ms
322318

323319
return (
324320
(

0 commit comments

Comments
 (0)