Skip to content

Commit ee8a073

Browse files
test: test initialization on static array problems
1 parent a5b4d3f commit ee8a073

File tree

1 file changed

+52
-26
lines changed

1 file changed

+52
-26
lines changed

test/initializationsystem.jl

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -593,22 +593,36 @@ end
593593
@parameters p q
594594
@brownian a b
595595
x = _x(t)
596-
596+
sarray_ctor = splat(SVector)
597597
# `System` constructor creates appropriate type with mtkbuild
598598
# `Problem` and `alg` create the problem to test and allow calling `init` with
599599
# the correct solver.
600600
# `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence
601601
# the system type (brownian vars to turn it into an SDE).
602-
@testset "$Problem with $(SciMLBase.parameterless_type(alg))" for (System, Problem, alg, rhss) in [
603-
(ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)),
604-
(ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]),
605-
(ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]),
606-
(ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])
607-
]
602+
@testset "$Problem with $(SciMLBase.parameterless_type(alg)) and $ctor ctor" for ((System, Problem, alg, rhss), (ctor, expectedT)) in Iterators.product(
603+
[
604+
(ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)),
605+
(ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]),
606+
(ModelingToolkit.System, DDEProblem,
607+
MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]),
608+
(ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])
609+
],
610+
[(identity, Any), (sarray_ctor, SVector)])
611+
u0_constructor = p_constructor = ctor
612+
if ctor !== identity
613+
Problem = Problem{false}
614+
end
608615
function test_parameter(prob, sym, val)
609616
if prob.u0 !== nothing
617+
@test prob.u0 isa expectedT
610618
@test init(prob, alg).ps[sym] val
611619
end
620+
@test prob.p.tunable isa expectedT
621+
initprob = prob.f.initialization_data.initializeprob
622+
if state_values(initprob) !== nothing
623+
@test state_values(initprob) isa expectedT
624+
end
625+
@test parameter_values(initprob).tunable isa expectedT
612626
@test solve(prob, alg).ps[sym] val
613627
end
614628
function test_initializesystem(sys, u0map, pmap, p, equation)
@@ -625,72 +639,72 @@ end
625639
@mtkbuild sys = System(
626640
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => missing], guesses = [p => 1.0])
627641
pmap[p] = 2q
628-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
642+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
629643
test_parameter(prob, p, 2.0)
630644
prob2 = remake(prob; u0 = u0map, p = pmap)
631-
prob2.ps[p] = 0.0
645+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
632646
test_parameter(prob2, p, 2.0)
633647
# `missing` default, provided guess
634648
@mtkbuild sys = System(
635649
[D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [p => 0.0])
636-
prob = Problem(sys, u0map, (0.0, 1.0))
650+
prob = Problem(sys, u0map, (0.0, 1.0); u0_constructor, p_constructor)
637651
test_parameter(prob, p, 2.0)
638652
test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y)
639653
prob2 = remake(prob; u0 = u0map)
640-
prob2.ps[p] = 0.0
654+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
641655
test_parameter(prob2, p, 2.0)
642656

643657
# `missing` to Problem, equation from default
644658
@mtkbuild sys = System(
645659
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
646660
pmap[p] = missing
647-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
661+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
648662
test_parameter(prob, p, 2.0)
649663
test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p)
650664
prob2 = remake(prob; u0 = u0map, p = pmap)
651-
prob2.ps[p] = 0.0
665+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
652666
test_parameter(prob2, p, 2.0)
653667
# `missing` to Problem, provided guess
654668
@mtkbuild sys = System(
655669
[D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0])
656-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
670+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
657671
test_parameter(prob, p, 2.0)
658672
test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p)
659673
prob2 = remake(prob; u0 = u0map, p = pmap)
660-
prob2.ps[p] = 0.0
674+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
661675
test_parameter(prob2, p, 2.0)
662676

663677
# No `missing`, default and guess
664678
@mtkbuild sys = System(
665679
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 0.0])
666680
delete!(pmap, p)
667-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
681+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
668682
test_parameter(prob, p, 2.0)
669683
test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p)
670684
prob2 = remake(prob; u0 = u0map, p = pmap)
671-
prob2.ps[p] = 0.0
685+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
672686
test_parameter(prob2, p, 2.0)
673687

674688
# Default overridden by Problem, guess provided
675689
@mtkbuild sys = System(
676690
[D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
677691
_pmap = merge(pmap, Dict(p => q))
678-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
692+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
679693
test_parameter(prob, p, _pmap[q])
680694
test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p)
681695
# Problem dependent value with guess, no `missing`
682696
@mtkbuild sys = System(
683697
[D(x) ~ y * q + p + rhss[1], D(y) ~ x * p + q + rhss[2]], t; guesses = [p => 0.0])
684698
_pmap = merge(pmap, Dict(p => 3q))
685-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
699+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
686700
test_parameter(prob, p, 3pmap[q])
687701

688702
# Should not be solved for:
689703
# Override dependent default with direct value
690704
@mtkbuild sys = System(
691705
[D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
692706
_pmap = merge(pmap, Dict(p => 1.0))
693-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
707+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
694708
@test prob.ps[p] 1.0
695709
initsys = prob.f.initialization_data.initializeprob.f.sys
696710
@test is_parameter(initsys, p)
@@ -699,7 +713,7 @@ end
699713
@parameters r::Int s::Int
700714
@mtkbuild sys = System(
701715
[D(x) ~ s * x + rhss[1], D(y) ~ y * r + rhss[2]], t; defaults = [s => 2r], guesses = [s => 1.0])
702-
prob = Problem(sys, u0map, (0.0, 1.0), [r => 1])
716+
prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]; u0_constructor, p_constructor)
703717
@test prob.ps[r] == 1
704718
@test prob.ps[s] == 2
705719
initsys = prob.f.initialization_data.initializeprob.f.sys
@@ -713,7 +727,7 @@ end
713727

714728
# Unsatisfiable initialization
715729
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0),
716-
[p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3])
730+
[p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3], u0_constructor, p_constructor)
717731
@test prob.f.initialization_data !== nothing
718732
@test solve(prob, alg).retcode == ReturnCode.InitialFailure
719733
cache = init(prob, alg)
@@ -790,8 +804,17 @@ end
790804

791805
prob_alg_combinations = zip(
792806
[NonlinearProblem, NonlinearLeastSquaresProblem], [nl_algs, nlls_algs])
793-
@testset "Parameter initialization" begin
807+
sarray_ctor = splat(SVector)
808+
@testset "Parameter initialization with ctor $ctor" for (ctor, expectedT) in [
809+
(identity, Any),
810+
(sarray_ctor, SVector)
811+
]
812+
u0_constructor = p_constructor = ctor
794813
function test_parameter(prob, alg, param, val)
814+
if prob.u0 !== nothing
815+
@test prob.u0 isa expectedT
816+
end
817+
@test prob.p.tunable isa expectedT
795818
integ = init(prob, alg)
796819
@test integ.ps[param]val rtol=1e-5
797820
# some algorithms are a little temperamental
@@ -817,19 +840,22 @@ end
817840
# guesses = [q => 1.0], initialization_eqs = [p^2 + q^2 + 2p * q ~ 0])
818841

819842
for (probT, algs) in prob_alg_combinations
820-
prob = probT(sys, [])
843+
if ctor != identity
844+
probT = probT{false}
845+
end
846+
prob = probT(sys, []; u0_constructor, p_constructor)
821847
@test prob.f.initialization_data !== nothing
822848
@test prob.f.initialization_data.initializeprobmap === nothing
823849
for alg in algs
824850
test_parameter(prob, alg, q, -2.0)
825851
end
826852

827853
# `update_initializeprob!` works
828-
prob.ps[p] = -2.0
854+
prob = remake(prob; p = setp_oop(prob, p)(prob, -2.0))
829855
for alg in algs
830856
test_parameter(prob, alg, q, 2.0)
831857
end
832-
prob.ps[p] = 2.0
858+
prob = remake(prob; p = setp_oop(prob, p)(prob, 2.0))
833859

834860
# `remake` works
835861
prob2 = remake(prob; p = [p => -2.0])

0 commit comments

Comments
 (0)