Skip to content

Commit 1643719

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat: add initialization support
1 parent aaefbb7 commit 1643719

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

src/integrators/type.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
3333
ksEltype, SolType, F, CacheType, IType, FP, O, dAbsType,
3434
dRelType, H,
3535
tstopsType, discType, FSALType, EventErrorType,
36-
CallbackCacheType, DV} <:
36+
CallbackCacheType, DV, IA} <:
3737
AbstractDDEIntegrator{algType, IIP, uType, tType}
3838
sol::SolType
3939
u::uType
@@ -95,6 +95,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
9595
integrator::IType
9696
fsalfirst::FSALType
9797
fsallast::FSALType
98+
initializealg::IA
9899
end
99100

100101
function (integrator::DDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing)

src/solve.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
6666
discontinuity_interp_points::Int = 10,
6767
discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12),
6868
discontinuity_reltol = 0,
69+
initializealg = DDEDefaultInit(),
6970
kwargs...)
7071
if haskey(kwargs, :initial_order)
7172
@warn "initial_order has been deprecated. Please specify order_discontinuity_t0 in the DDEProblem instead."
@@ -350,7 +351,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
350351
typeof(d_discontinuities_propagated),
351352
typeof(fsalfirst),
352353
typeof(last_event_error), typeof(callback_cache),
353-
typeof(differential_vars)}(sol, u, k,
354+
typeof(differential_vars), typeof(initializealg)}(sol, u, k,
354355
t0,
355356
tType(dt),
356357
f_with_history,
@@ -402,10 +403,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
402403
stats,
403404
history,
404405
differential_vars,
405-
ode_integrator, fsalfirst, fsallast)
406+
ode_integrator, fsalfirst, fsallast, initializealg)
406407

407408
# initialize DDE integrator
408409
if initialize_integrator
410+
DiffEqBase.initialize_dae!(integrator)
409411
initialize_solution!(integrator)
410412
OrdinaryDiffEqCore.initialize_callbacks!(integrator, initialize_save)
411413
OrdinaryDiffEqCore.initialize!(integrator)
@@ -538,3 +540,18 @@ function initialize_tstops_d_discontinuities_propagated(::Type{T}, tstops,
538540

539541
return tstops_propagated, d_discontinuities_propagated
540542
end
543+
544+
struct DDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end
545+
546+
function DiffEqBase.initialize_dae!(integrator::DDEIntegrator, initializealg = integrator.initializealg)
547+
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg,
548+
Val(DiffEqBase.isinplace(integrator.sol.prob)))
549+
end
550+
551+
function OrdinaryDiffEqCore._initialize_dae!(integrator::DDEIntegrator, prob, ::DDEDefaultInit, isinplace)
552+
if SciMLBase.has_initializeprob(prob.f)
553+
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
554+
else
555+
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
556+
end
557+
end

test/integrators/initialization.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using DelayDiffEq
2+
using SciMLBase
3+
using LinearAlgebra
4+
using Test
5+
6+
@testset "CheckInit" begin
7+
u0_good = [0.99, 0.01, 0.0]
8+
sir_history(p, t) = [1.0, 0.0, 0.0]
9+
tspan = (0.0, 40.0)
10+
p == 0.5, τ = 4.0)
11+
12+
function sir_ddae!(du, u, h, p, t)
13+
S, I, R = u
14+
γ, τ = p
15+
infection = γ * I * S
16+
Sd, Id, _ = h(p, t - τ)
17+
recovery = γ * Id * Sd
18+
@inbounds begin
19+
du[1] = -infection
20+
du[2] = infection - recovery
21+
du[3] = S + I + R - 1
22+
end
23+
nothing
24+
end
25+
26+
prob_ddae = DDEProblem(
27+
DDEFunction{true}(sir_ddae!;
28+
mass_matrix = Diagonal([1.0, 1.0, 0.0])),
29+
u0,
30+
sir_history,
31+
tspan,
32+
p;
33+
constant_lags = (p.τ,))
34+
alg = MethodOfSteps(Rosenbrock23())
35+
@test_nowarn init(prob_ddae, alg)
36+
prob.u0[1] = 2.0
37+
@test_throws SciMLBase.CheckInitFailureError init(prob_ddae, alg)
38+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ if GROUP == "All" || GROUP == "Integrators"
8484
@time @safetestset "Verner Tests" begin
8585
include("integrators/verner.jl")
8686
end
87+
@time @safetestset "Initialization" begin
88+
include("integrators/initialization.jl")
89+
end
8790
end
8891

8992
if GROUP == "All" || GROUP == "Regression"

0 commit comments

Comments
 (0)