Skip to content

Commit ec386fe

Browse files
committed
Refactor constraints
1 parent 86d4144 commit ec386fe

File tree

8 files changed

+568
-273
lines changed

8 files changed

+568
-273
lines changed

src/ModelingToolkit.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ include("systems/imperative_affect.jl")
150150
include("systems/callbacks.jl")
151151
include("systems/problem_utils.jl")
152152

153+
include("systems/optimization/constraints_system.jl")
154+
include("systems/optimization/optimizationsystem.jl")
155+
include("systems/optimization/modelingtoolkitize.jl")
156+
153157
include("systems/nonlinear/nonlinearsystem.jl")
154158
include("systems/nonlinear/homotopy_continuation.jl")
155159
include("systems/diffeqs/odesystem.jl")
@@ -165,10 +169,6 @@ include("systems/discrete_system/discrete_system.jl")
165169

166170
include("systems/jumps/jumpsystem.jl")
167171

168-
include("systems/optimization/constraints_system.jl")
169-
include("systems/optimization/optimizationsystem.jl")
170-
include("systems/optimization/modelingtoolkitize.jl")
171-
172172
include("systems/pde/pdesystem.jl")
173173

174174
include("systems/sparsematrixclil.jl")

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ for prop in [:eqs
983983
:structure
984984
:op
985985
:constraints
986+
:constraintsystem
986987
:controls
987988
:loss
988989
:bcs

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 43 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
827827
if !iscomplete(sys)
828828
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
829829
end
830+
831+
if !isnothing(get_constraintsystem(sys))
832+
error("An ODESystem with constraints cannot be used to construct a regular ODEProblem.
833+
Consider a BVProblem instead.")
834+
end
835+
830836
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
831837
t = tspan !== nothing ? tspan[1] : tspan,
832838
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
@@ -866,18 +872,23 @@ Create a boundary value problem from the [`ODESystem`](@ref).
866872
must have either an initial guess supplied using `guesses` or a fixed initial
867873
value specified using `u0map`.
868874
869-
`constraints` are used to specify boundary conditions to the ODESystem in the
870-
form of equations. These values should specify values that state variables should
875+
Boundary value conditions are supplied to ODESystems
876+
in the form of a ConstraintsSystem. These equations
877+
should specify values that state variables should
871878
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
872879
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
873-
specified as one of the equations used to build the `ODESystem`. Below is an example.
880+
specified as one of the equations used to build the `ODESystem`.
881+
882+
If an ODESystem without `constraints` is specified, it will be treated as an initial value problem.
874883
875884
```julia
876-
@parameters g
885+
@parameters g t_c = 0.5
877886
@variables x(..) y(t) [state_priority = 10] λ(t)
878887
eqs = [D(D(x(t))) ~ λ * x(t)
879888
D(D(y)) ~ λ * y - g
880889
x(t)^2 + y^2 ~ 1]
890+
cstr = [x(0.5) ~ 1]
891+
@named cstrs = ConstraintsSystem(cstr, t)
881892
@mtkbuild pend = ODESystem(eqs, t)
882893
883894
tspan = (0.0, 1.5)
@@ -889,9 +900,7 @@ specified as one of the equations used to build the `ODESystem`. Below is an exa
889900
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
890901
```
891902
892-
If no `constraints` are specified, the problem will be treated as an initial value problem.
893-
894-
If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting
903+
If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
895904
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
896905
"""
897906
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -916,7 +925,7 @@ end
916925
function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
917926
tspan = get_tspan(sys),
918927
parammap = DiffEqBase.NullParameters();
919-
constraints = nothing, guesses = Dict(),
928+
guesses = Dict(),
920929
version = nothing, tgrad = false,
921930
callback = nothing,
922931
check_length = true,
@@ -930,21 +939,14 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
930939
end
931940
!isnothing(callback) && error("BVP solvers do not support callbacks.")
932941

933-
has_alg_eqs(sys) && error("The BVProblem currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
942+
has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
934943

935-
constraintsts = nothing
936-
constraintps = nothing
937944
sts = unknowns(sys)
938945
ps = parameters(sys)
939946

940-
# Constraint validation
941947
if !isnothing(constraints)
942-
constraints isa Equation ||
943-
constraints isa Vector{Equation} ||
944-
error("Constraints must be specified as an equation or a vector of equations.")
945-
946948
(length(constraints) + length(u0map) > length(sts)) &&
947-
error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
949+
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
948950
end
949951

950952
# ODESystems without algebraic equations should use both fixed values + guesses
@@ -957,97 +959,60 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
957959
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
958960
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
959961

960-
bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)
961-
962+
bc = generate_function_bc(sys, u0, u0_idxs, tspan, iip)
962963
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
963964
end
964965

965966
get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
966967

967-
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
968-
function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
969-
for var in constraintsts
970-
if length(arguments(var)) > 1
971-
error("Too many arguments for variable $var.")
972-
elseif isequal(arguments(var)[1], iv)
973-
var sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
974-
error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
975-
else
976-
operation(var)(iv) sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
977-
end
978-
end
979-
980-
for var in constraintps
981-
if !iscall(var)
982-
var ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
983-
else
984-
length(arguments(var)) > 1 && error("Too many arguments for parameter $var.")
985-
operation(var) ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
986-
end
987-
end
988-
end
989-
990968
"""
991-
process_constraints(sys, constraints, u0, tspan, iip)
969+
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
992970
993-
Given an ODESystem with some constraints, generate the boundary condition function.
971+
Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
994972
"""
995-
function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, iip)
996-
973+
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
997974
iv = get_iv(sys)
998975
sts = get_unknowns(sys)
999976
ps = get_ps(sys)
1000977
np = length(ps)
1001978
ns = length(sts)
979+
conssys = get_constraintsystem(sys)
980+
cons = constraints(conssys)
1002981

1003982
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
1004983
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
1005984

1006985
@variables sol(..)[1:ns] p[1:np]
1007986
exprs = Any[]
1008987

1009-
constraintsts = OrderedSet()
1010-
constraintps = OrderedSet()
988+
for st in get_unknowns(cons)
989+
x = operation(st)
990+
t = first(arguments(st))
991+
idx = stidxmap[x(iv)]
1011992

1012-
!isnothing(constraints) && for cons in constraints
1013-
collect_vars!(constraintsts, constraintps, cons, iv)
1014-
validate_constraint_syms(cons, constraintsts, constraintps, Set(sts), Set(ps), iv)
1015-
expr = cons.rhs - cons.lhs
1016-
1017-
for st in constraintsts
1018-
x = operation(st)
1019-
t = arguments(st)[1]
1020-
idx = stidxmap[x(iv)]
1021-
1022-
expr = Symbolics.substitute(expr, Dict(x(t) => sol(t)[idx]))
1023-
end
993+
cons = Symbolics.substitute(cons, Dict(x(t) => sol(t)[idx]))
994+
end
1024995

1025-
for var in constraintps
1026-
if iscall(var)
1027-
x = operation(var)
1028-
t = arguments(var)[1]
1029-
idx = pidxmap[x]
996+
for var in get_parameters(cons)
997+
if iscall(var)
998+
x = operation(var)
999+
t = arguments(var)[1]
1000+
idx = pidxmap[x]
10301001

1031-
expr = Symbolics.substitute(expr, Dict(x(t) => p[idx]))
1032-
else
1033-
idx = pidxmap[var]
1034-
expr = Symbolics.substitute(expr, Dict(var => p[idx]))
1035-
end
1002+
cons = Symbolics.substitute(cons, Dict(x(t) => p[idx]))
1003+
else
1004+
idx = pidxmap[var]
1005+
cons = Symbolics.substitute(cons, Dict(var => p[idx]))
10361006
end
1037-
1038-
empty!(constraintsts)
1039-
empty!(constraintps)
1040-
push!(exprs, expr)
10411007
end
10421008

1043-
init_cond_exprs = Any[]
1044-
1009+
init_conds = Any[]
10451010
for i in u0_idxs
10461011
expr = sol(tspan[1])[i] - u0[i]
1047-
push!(init_cond_exprs, expr)
1012+
push!(init_conds, expr)
10481013
end
10491014

1050-
exprs = vcat(init_cond_exprs, exprs)
1015+
exprs = vcat(init_conds, cons)
10511016
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
10521017
if iip
10531018
return (resid, u, p, t) -> bcs[2](resid, u, p)

0 commit comments

Comments
 (0)