Skip to content

Commit a15c670

Browse files
committed
fix sym validation
1 parent 90ce80d commit a15c670

File tree

2 files changed

+49
-57
lines changed

2 files changed

+49
-57
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ end
221221
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
222222
controls = Num[],
223223
observed = Equation[],
224-
constraints = Equation[],
224+
constraintsystem = nothing,
225225
systems = ODESystem[],
226226
tspan = nothing,
227227
name = nothing,
@@ -286,26 +286,17 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
286286
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
287287
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
288288

289-
constraintsys = nothing
290-
if !isempty(constraints)
291-
constraintsys = process_constraint_system(constraints, dvs′, ps′, iv, systems)
292-
dvset = Set(dvs′)
293-
pset = Set(ps′)
294-
for st in get_unknowns(constraintsys)
295-
iscall(st) ?
296-
!in(operation(st)(iv), dvset) && push!(dvs′, st) :
297-
!in(st, dvset) && push!(dvs′, st)
298-
end
299-
for p in parameters(constraintsys)
300-
!in(p, pset) && push!(ps′, p)
301-
end
302-
end
303-
304289
if is_dde === nothing
305290
is_dde = _check_if_dde(deqs, iv′, systems)
306291
end
292+
293+
if !isempty(systems)
294+
cons = get_constraintsystems.(systems)
295+
@set! constraintsystem.systems = cons
296+
end
297+
307298
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
308-
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsys, tgrad, jac,
299+
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
309300
ctrl_jac, Wfact, Wfact_t, name, description, systems,
310301
defaults, guesses, nothing, initializesystem,
311302
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
@@ -377,9 +368,22 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
377368
end
378369
algevars = setdiff(allunknowns, diffvars)
379370

371+
if !isempty(constraints)
372+
consvars = OrderedSet()
373+
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
374+
for st in get_unknowns(constraintsystem)
375+
iscall(st) ?
376+
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
377+
!in(st, allunknowns) && push!(consvars, st)
378+
end
379+
for p in parameters(constraintsystem)
380+
!in(p, new_ps) && push!(new_ps, p)
381+
end
382+
end
383+
380384
# the orders here are very important!
381385
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
382-
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); constraints, kwargs...)
386+
collect(Iterators.flatten((diffvars, algevars, consvars))), collect(new_ps); constraintsystem, kwargs...)
383387
end
384388

385389
# NOTE: equality does not check cached Jacobian
@@ -791,55 +795,38 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
791795
end
792796

793797
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
794-
# - Any callable with multiple arguments will error.
795798
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
796799
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
797-
function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv)
800+
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
801+
isempty(constraints) && return nothing
802+
803+
constraintsts = OrderedSet()
804+
constraintps = OrderedSet()
805+
806+
# Hack? to extract parameters from callable variables in constraints.
807+
for cons in constraints
808+
collect_vars!(constraintsts, constraintps, cons, iv)
809+
end
810+
811+
# Validate the states.
798812
for var in constraintsts
799813
if !iscall(var)
800-
occursin(iv, var) && var sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))
814+
occursin(iv, var) && (var sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
801815
elseif length(arguments(var)) > 1
802816
throw(ArgumentError("Too many arguments for variable $var."))
803817
elseif length(arguments(var)) == 1
804-
arg = first(arguments(var))
818+
arg = only(arguments(var))
805819
operation(var)(iv) sts ||
806820
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
807821

808822
isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
809823
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
824+
825+
isparameter(arg) && push!(constraintps, arg)
810826
else
811827
var sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
812828
end
813829
end
814830

815-
for var in constraintps
816-
!iscall(var) && continue
817-
818-
if length(arguments(var)) > 1
819-
throw(ArgumentError("Too many arguments for parameter $var in equation $eq."))
820-
elseif length(arguments(var)) == 1
821-
arg = first(arguments(var))
822-
operation(var) ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem."))
823-
824-
isequal(arg, iv) ||
825-
arg isa Integer ||
826-
arg isa AbstractFloat ||
827-
throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
828-
end
829-
end
830-
end
831-
832-
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv, subsys::Vector{ODESystem}; name = :cons)
833-
isempty(constraints) && return nothing
834-
835-
constraintsts = OrderedSet()
836-
constraintps = OrderedSet()
837-
838-
for cons in constraints
839-
syms = collect_vars!(constraintsts, constraintps, cons, iv)
840-
end
841-
validate_constraint_syms(constraintsts, constraintps, Set(sts), Set(ps), iv)
842-
843-
constraint_subsys = get_constraintsystem.(subsys)
844-
ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); systems = constraint_subsys, name)
831+
ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
845832
end

test/odesystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,23 +1562,28 @@ end
15621562
# Test variables + parameters infer correctly.
15631563
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
15641564
@test issetequal(parameters(sys), [a, c, d, e])
1565-
@test issetequal(unknowns(sys), [x(t), y(t)])
1565+
@test issetequal(unknowns(sys), [x(t), y(t), z(t)])
15661566

15671567
@parameters t_c
15681568
cons = [x(t_c) ~ 3]
15691569
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
1570-
@test_broken issetequal(parameters(sys), [a, e, t_c]) # TODO: unbreak this.
1570+
@test issetequal(parameters(sys), [a, e, t_c])
1571+
1572+
@parameters g(..) h i
1573+
cons = [g(h, i) * x(3) ~ c]
1574+
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
1575+
@test issetequal(parameters(sys), [g, h, i, a, e, c])
15711576

15721577
# Test that bad constraints throw errors.
1573-
cons = [x(3, 4) ~ 3]
1578+
cons = [x(3, 4) ~ 3] # unknowns cannot have multiple args.
15741579
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
15751580

1576-
cons = [x(y(t)) ~ 2]
1581+
cons = [x(y(t)) ~ 2] # unknown arg must be parameter, value, or t
15771582
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
15781583

15791584
@variables u(t) v
15801585
cons = [x(t) * u ~ 3]
15811586
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
15821587
cons = [x(t) * v ~ 3]
1583-
@test_nowarn @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
1588+
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.
15841589
end

0 commit comments

Comments
 (0)