Skip to content

Commit 4c47afe

Browse files
committed
separate out ODESystem construction
1 parent 66a6ee6 commit 4c47afe

File tree

8 files changed

+75
-45
lines changed

8 files changed

+75
-45
lines changed

demo/STGneuron.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ dynamics = HodgkinHuxley(channels, gradients);
2626
@named neuron = CompartmentSystem(Vₘ, dynamics;
2727
geometry = geo,
2828
extensions = [calcium_conversion]);
29-
t_total = 5000.0
30-
sim = Simulation(neuron, t_total * ms)
29+
30+
sim = Simulation(neuron, (0ms, 4000ms))
3131
solution = solve(sim, Rosenbrock23());
3232

3333
# Plot at 5kHz sampling

demo/hodgkinhuxley.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Classic Hodgkin Huxley neuron with a "current pulse" stimulus
22
using Conductor, Unitful, ModelingToolkit, OrdinaryDiffEq, Plots
33
import Unitful: mV, mS, cm, µm, pA, nA, mA, µA, ms
4-
import Conductor: Na, K # shorter aliases for Sodium/Potassium
54

65
Vₘ = ParentScope(MembranePotential(-65mV))
76

@@ -24,15 +23,13 @@ kdr_kinetics = [
2423
@named leak = IonChannel(Leak, max_g = 0.3mS / cm^2)
2524

2625
channels = [NaV, Kdr, leak];
27-
reversals = Equilibria([Na => 50.0mV, K => -77.0mV, Leak => -54.4mV])
28-
26+
reversals = Equilibria([Sodium => 50.0mV, Potassium => -77.0mV, Leak => -54.4mV])
2927
@named pulse_stim = PulseTrain(amplitude = 400.0pA, duration = 100ms, delay = 100ms)
30-
3128
dynamics = HodgkinHuxley(channels, reversals);
3229

3330
@named neuron = Compartment(Vₘ, dynamics; geometry = Sphere(radius = 20µm),
34-
stimuli = [pulse_stim])
31+
stimuli = [pulse_stim])
3532

36-
sim = Simulation(neuron, 300ms)
33+
sim = Simulation(neuron, (0ms,300ms))
3734
solution = solve(sim, Rosenbrock23(), abstol = 0.01, reltol = 0.01, dtmax = 100.0);
3835
plot(solution; size = (1200, 800))

src/Conductor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ abstract type AbstractConductanceSystem <: AbstractTimeDependentSystem end
117117
abstract type AbstractCompartmentSystem <: AbstractTimeDependentSystem end
118118
abstract type AbstractNeuronalNetworkSystem <: AbstractTimeDependentSystem end
119119

120+
const AbstractConductorSystem = Union{AbstractCurrentSystem,
121+
AbstractConductanceSystem,
122+
AbstractCompartmentSystem,
123+
AbstractNeuronalNetworkSystem}
124+
120125
# Model properties
121126
abstract type ConductanceModel end
122127

src/simulation.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
function simplify_simulation(sys, time)
1+
2+
# for now, fallback to ODESystem conversion
3+
function ModelingToolkit.ODESystem(sys::AbstractConductorSystem; simplify = true)
24
odesys = convert(ODESystem, sys)
3-
t_val = ustrip(Float64, ms, time)
4-
return t_val, structural_simplify(odesys)
5+
return simplify ? structural_simplify(odesys) : odesys
56
end
67

8+
# simplified = settunable(simplified, tunable)
9+
710
"""
811
$(TYPEDSIGNATURES)
912
@@ -12,16 +15,11 @@ duration, `time`.
1215
1316
If `return_system == true`, returns a simplified `ODESystem` instead.
1417
"""
15-
function Simulation(neuron::AbstractCompartmentSystem, time::Time; return_system = false,
16-
jac = false, sparse = false,
17-
parallel = Symbolics.SerialForm())
18-
t_val, simplified = simplify_simulation(neuron, time)
19-
if return_system
20-
return simplified
21-
else
22-
@info repr("text/plain", simplified)
23-
return ODEProblem(simplified, [], (0.0, t_val), []; jac, sparse, parallel)
24-
end
18+
function Simulation(neuron::AbstractCompartmentSystem, tspan;
19+
simplify = true, parallel = Symbolics.SerialForm(), kwargs...)
20+
simplified = ODESystem(neuron; simplify)
21+
tstart, tstop = time_span(tspan)
22+
return ODEProblem(simplified, [], (tstart, tstop), []; parallel, kwargs...)
2523
end
2624

2725
struct NetworkParameters{T}
@@ -31,46 +29,49 @@ end
3129

3230
Base.getindex(x::NetworkParameters, i) = x.ps[i]
3331
topology(x::NetworkParameters) = getfield(x, :topology)
32+
3433
function get_weights(integrator, model)
3534
topo = topology(integrator.p)
3635
return graph(topo)[model]
3736
end
38-
function Simulation(network::NeuronalNetworkSystem, time::Time; return_system = false,
39-
jac = false, sparse = false, parallel = Symbolics.SerialForm(), continuous_events = false,
40-
refractory = true)
41-
t_val, simplified = simplify_simulation(network, time)
42-
return_system && return simplified
37+
38+
function Simulation(network::NeuronalNetworkSystem, tspan; simplify = true,
39+
parallel = Symbolics.SerialForm(), continuous_events = false,
40+
refractory = true, kwargs...)
41+
odesys = ODESystem(network; simplify)
42+
tstart, tstop = time_span(tspan)
4343
if !any(iseventbased.(synaptic_systems(network)))
44-
return ODEProblem(simplified, [], (0.0, t_val), []; jac, sparse, parallel)
44+
return ODEProblem(odesys, [], (tstart, tstop), []; parallel, kwargs...)
4545
else
46-
cb = generate_callback(network, simplified; continuous_events, refractory)
47-
prob = ODEProblem(simplified, [], (0.0, t_val), []; callback = cb, jac, sparse, parallel)
48-
remake(prob; p = NetworkParameters(prob.p, get_topology(network) ))
46+
cb = generate_callback(network, odesys; continuous_events, refractory)
47+
prob = ODEProblem(odesys, [], (tstart, tstop), [];
48+
callback = cb, parallel, kwargs...)
49+
remake(prob; p = NetworkParameters(prob.p, get_topology(network)))
4950
end
5051
end
5152

5253
# if continuous, condition has vector cb signature: cond(out, u, t, integrator)
53-
function generate_callback_condition(network, simplified; continuous_events, refractory)
54-
voltage_indices = map_voltage_indices(network, simplified; roots_only = true)
54+
function generate_callback_condition(network, odesys; continuous_events, refractory)
55+
voltage_indices = map_voltage_indices(network, odesys; roots_only = true)
5556
if continuous_events
5657
return ContinuousSpikeDetection(voltage_indices)
5758
else # discrete condition for each compartment
5859
return [DiscreteSpikeDetection(voltage_index, refractory) for voltage_index in voltage_indices]
5960
end
6061
end
6162

62-
function generate_callback_affects(network, simplified)
63+
function generate_callback_affects(network, odesys)
6364
spike_affects = []
6465
for sys in synaptic_systems(network)
65-
push!(spike_affects, SpikeAffect(sys, network, simplified))
66+
push!(spike_affects, SpikeAffect(sys, network, odesys))
6667
end
6768
tailcall = nothing # placeholder for voltage reset
6869
return NetworkAffects(spike_affects, tailcall)
6970
end
7071

71-
function generate_callback(network, simplified; continuous_events, refractory)
72-
cb_condition = generate_callback_condition(network, simplified; continuous_events, refractory)
73-
cb_affect = generate_callback_affects(network, simplified)
72+
function generate_callback(network, odesys; continuous_events, refractory)
73+
cb_condition = generate_callback_condition(network, odesys; continuous_events, refractory)
74+
cb_affect = generate_callback_affects(network, odesys)
7475
if continuous_events
7576
return VectorContinuousCallback(cb_condition, cb_affect,
7677
length(cb_condition.voltage_indices))

src/util.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ function build_toplevel(system)
5353
build_toplevel!(dvs, ps, eqs, defs, system)
5454
end
5555

56+
function time_span(tspan::Tuple{Time,Time})
57+
ustrip(Float64, ms, tspan[1]), ustrip(Float64, ms, tspan[2])
58+
end
59+
60+
time_span(tspan::Tuple{Float64,Float64}) = tspan
61+
time_span(tspan::Time) = zero(Float64), ustrip(Float64, ms, tspan)
62+
time_span(tspan::Real) = zero(Float64), Float64(tspan)
63+
5664
heaviside(x) = ifelse(x > zero(x), one(x), zero(x))
5765
@register_symbolic heaviside(x)
5866
ModelingToolkit.get_unit(op::typeof(heaviside), args) = ms^-1
@@ -74,3 +82,22 @@ function set_symarray_metadata(x, ctx, val)
7482
end
7583
end
7684

85+
function settunable(sys, ps)
86+
new_ps = setmetadata.(ps, ModelingToolkit.VariableTunable, true)
87+
@set! sys.ps = union(new_ps, parameters(sys))
88+
return sys
89+
end
90+
91+
function setbounds(sys, p_dists)
92+
new_ps = parameters(sys)
93+
for pair in p_dists
94+
i = findfirst(isequal(first(pair)), new_ps)
95+
i == nothing && throw("Parameter $(first(pair)) not found.")
96+
new_ps[i] = setmetadata.(new_ps, ModelingToolkit.VariableBounds, second(pair))
97+
end
98+
@set! sys.ps = union(new_ps, parameters(sys))
99+
return sys
100+
end
101+
102+
function setdistributions(sys, ps) end
103+

test/hodgkinhuxley.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ using Unitful: mV, mS, cm, µm, µA, ms, pA
5454
states(neuron),
5555
parameters(neuron)]) == [11, 11, 8]
5656

57-
time = 300.0
58-
sim_sys = Simulation(neuron, time * ms; return_system = true)
57+
sim_sys = ODESystem(neuron)
5958

6059
@test length.([equations(sim_sys),
6160
states(sim_sys),
@@ -136,6 +135,7 @@ using Unitful: mV, mS, cm, µm, µA, ms, pA
136135
return nothing
137136
end
138137

138+
time = 300.0
139139
byhand_prob = ODEProblem{true}(hodgkin_huxley!, u0, (0.0, 300.0), p)
140140
mtk_prob = ODEProblem(sim_sys, [], (0.0, time), [])
141141

test/prinzneuron.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ include(normpath(@__DIR__, "..", "demo", "prinz_kinetics.jl"))
2424
states(neuron),
2525
parameters(neuron)]) == [31, 31, 17]
2626

27-
time = 2000
28-
simul_sys = Simulation(neuron, time * ms; return_system = true)
27+
simul_sys = ODESystem(neuron)
2928

3029
@test length.([equations(simul_sys),
3130
states(simul_sys),
@@ -184,12 +183,13 @@ include(normpath(@__DIR__, "..", "demo", "prinz_kinetics.jl"))
184183
end
185184
end
186185

186+
time = 2000.0
187187
byhand_prob = ODEProblem{true}(prinz_neuron!, u0, (0.0, time), p)
188188
mtk_prob = ODEProblem(simul_sys, [], (0.0, time), [])
189189
byhand_sol = solve(byhand_prob, Rosenbrock23(), reltol = 1e-8, abstol = 1e-8)
190190
current_mtk_sol = solve(mtk_prob, Rosenbrock23(), reltol = 1e-8, abstol = 1e-8)
191191

192-
tsteps = 0.0:0.025:2000.0
192+
tsteps = 0.0:0.025:time
193193
byhand_out = Array(byhand_sol(tsteps, idxs = 2))
194194
current_mtk_out = current_mtk_sol(tsteps)[Vₘ]
195195

test/simplesynapse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ import Conductor: Na, K
6464
states(network),
6565
parameters(network)]) == [24, 24, 20]
6666

67-
ttot = 250.0
68-
simul_sys = Simulation(network, ttot * ms; return_system = true)
67+
simul_sys = ODESystem(network)
6968

7069
@test length.([equations(simul_sys),
7170
states(simul_sys),
@@ -235,8 +234,9 @@ import Conductor: Na, K
235234
end
236235

237236
# Solve and check for invariance
237+
ttot = 250.0
238238
byhand_prob = ODEProblem{true}(simple_synapse!, u0, (0.0, ttot), p)
239-
sim = Simulation(network, ttot * ms)
239+
sim = Simulation(network, (0.0ms, ttot * ms))
240240
byhand_sol = solve(byhand_prob, Rosenbrock23(), reltol = 1e-9, abstol = 1e-9,
241241
saveat = 0.025)
242242
current_mtk_sol = solve(sim, Rosenbrock23(), reltol = 1e-9, abstol = 1e-9,

0 commit comments

Comments
 (0)