Skip to content

Commit 90ce80d

Browse files
committed
refactor tests
1 parent ec386fe commit 90ce80d

File tree

3 files changed

+39
-55
lines changed

3 files changed

+39
-55
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -943,9 +943,10 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
943943

944944
sts = unknowns(sys)
945945
ps = parameters(sys)
946+
constraintsys = get_constraintsystem(sys)
946947

947-
if !isnothing(constraints)
948-
(length(constraints) + length(u0map) > length(sts)) &&
948+
if !isnothing(constraintsys)
949+
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
949950
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
950951
end
951952

@@ -976,33 +977,35 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
976977
ps = get_ps(sys)
977978
np = length(ps)
978979
ns = length(sts)
979-
conssys = get_constraintsystem(sys)
980-
cons = constraints(conssys)
981-
982980
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
983981
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
984982

985983
@variables sol(..)[1:ns] p[1:np]
986-
exprs = Any[]
987984

988-
for st in get_unknowns(cons)
989-
x = operation(st)
990-
t = first(arguments(st))
991-
idx = stidxmap[x(iv)]
985+
conssys = get_constraintsystem(sys)
986+
cons = Any[]
987+
if !isnothing(conssys)
988+
cons = [con.lhs - con.rhs for con in constraints(conssys)]
992989

993-
cons = Symbolics.substitute(cons, Dict(x(t) => sol(t)[idx]))
994-
end
990+
for st in get_unknowns(conssys)
991+
x = operation(st)
992+
t = only(arguments(st))
993+
idx = stidxmap[x(iv)]
995994

996-
for var in get_parameters(cons)
997-
if iscall(var)
998-
x = operation(var)
999-
t = arguments(var)[1]
1000-
idx = pidxmap[x]
995+
cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
996+
end
1001997

1002-
cons = Symbolics.substitute(cons, Dict(x(t) => p[idx]))
1003-
else
1004-
idx = pidxmap[var]
1005-
cons = Symbolics.substitute(cons, Dict(var => p[idx]))
998+
for var in parameters(conssys)
999+
if iscall(var)
1000+
x = operation(var)
1001+
t = only(arguments(var))
1002+
idx = pidxmap[x]
1003+
1004+
cons = map(c -> Symbolics.substitute(c, Dict(x(t) => p[idx])), cons)
1005+
else
1006+
idx = pidxmap[var]
1007+
cons = map(c -> Symbolics.substitute(c, Dict(var => p[idx])), cons)
1008+
end
10061009
end
10071010
end
10081011

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -802,13 +802,11 @@ function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv)
802802
throw(ArgumentError("Too many arguments for variable $var."))
803803
elseif length(arguments(var)) == 1
804804
arg = first(arguments(var))
805-
operation(var)(iv) sts || throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
805+
operation(var)(iv) sts ||
806+
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
806807

807-
isequal(arg, iv) ||
808-
isparameter(arg) ||
809-
arg isa Integer ||
810-
arg isa AbstractFloat ||
811-
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
808+
isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
809+
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
812810
else
813811
var sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
814812
end
@@ -824,7 +822,6 @@ function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv)
824822
operation(var) ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem."))
825823

826824
isequal(arg, iv) ||
827-
isparameter(arg) ||
828825
arg isa Integer ||
829826
arg isa AbstractFloat ||
830827
throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

test/bvproblem.jl

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using BenchmarkTools
55
using ModelingToolkit
66
using SciMLBase
77
using ModelingToolkit: t_nounits as t, D_nounits as D
8-
import ModelingToolkit: process_constraints
98

109
### Test Collocation solvers on simple problems
1110
solvers = [MIRK4]
@@ -113,8 +112,8 @@ let
113112
end
114113

115114
u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
116-
genbc_iip = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, true)
117-
genbc_oop = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, false)
115+
genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, true)
116+
genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, false)
118117

119118
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
120119
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
@@ -131,7 +130,8 @@ let
131130
@test sol1 sol2
132131

133132
# Test with a constraint.
134-
constraints = [y(0.5) ~ 2.]
133+
constr = [y(0.5) ~ 2.]
134+
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
135135

136136
function bc!(resid, u, p, t)
137137
resid[1] = u(0.0)[1] - 1.
@@ -142,13 +142,13 @@ let
142142
end
143143

144144
u0 = [1, 1.]
145-
genbc_iip = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, true)
146-
genbc_oop = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, false)
145+
genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, true)
146+
genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, false)
147147

148148
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
149149
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
150-
bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
151-
bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
150+
bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
151+
bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
152152

153153
sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
154154
sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
@@ -158,7 +158,7 @@ let
158158

159159
bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
160160
bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
161-
bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
161+
bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
162162

163163
sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
164164
sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
@@ -197,12 +197,6 @@ function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.
197197
end
198198
end
199199

200-
solvers = [RadauIIa3, RadauIIa5, RadauIIa7,
201-
LobattoIIIa2, LobattoIIIa4, LobattoIIIa5,
202-
LobattoIIIb2, LobattoIIIb3, LobattoIIIb4, LobattoIIIb5,
203-
LobattoIIIc2, LobattoIIIc3, LobattoIIIc4, LobattoIIIc5]
204-
weird = [MIRK2, MIRK5, RadauIIa2]
205-
daesolvers = []
206200
# Simple ODESystem with BVP constraints.
207201
let
208202
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@@ -222,24 +216,14 @@ let
222216

223217
# Testing that more complicated constr give correct solutions.
224218
constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
219+
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
225220
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)
226221
test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
227222

228223
constr =* β - x(.6) ~ 0.0, y(.2) ~ 3.]
224+
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
229225
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
230226
test_solvers(solvers, bvp, u0map, constr)
231-
232-
# Testing that errors are properly thrown when malformed constr are given.
233-
@variables bad(..)
234-
constr = [x(1.) + bad(3.) ~ 10]
235-
@test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
236-
237-
constr = [x(t) + y(t) ~ 3]
238-
@test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
239-
240-
@parameters bad2
241-
constr = [bad2 + x(0.) ~ 3]
242-
@test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
243227
end
244228

245229
# Cartesian pendulum from the docs.

0 commit comments

Comments
 (0)