|
221 | 221 | function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; |
222 | 222 | controls = Num[], |
223 | 223 | observed = Equation[], |
224 | | - constraints = Equation[], |
| 224 | + constraintsystem = nothing, |
225 | 225 | systems = ODESystem[], |
226 | 226 | tspan = nothing, |
227 | 227 | name = nothing, |
@@ -286,26 +286,17 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; |
286 | 286 | cont_callbacks = SymbolicContinuousCallbacks(continuous_events) |
287 | 287 | disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) |
288 | 288 |
|
289 | | - constraintsys = nothing |
290 | | - if !isempty(constraints) |
291 | | - constraintsys = process_constraint_system(constraints, dvs′, ps′, iv, systems) |
292 | | - dvset = Set(dvs′) |
293 | | - pset = Set(ps′) |
294 | | - for st in get_unknowns(constraintsys) |
295 | | - iscall(st) ? |
296 | | - !in(operation(st)(iv), dvset) && push!(dvs′, st) : |
297 | | - !in(st, dvset) && push!(dvs′, st) |
298 | | - end |
299 | | - for p in parameters(constraintsys) |
300 | | - !in(p, pset) && push!(ps′, p) |
301 | | - end |
302 | | - end |
303 | | - |
304 | 289 | if is_dde === nothing |
305 | 290 | is_dde = _check_if_dde(deqs, iv′, systems) |
306 | 291 | end |
| 292 | + |
| 293 | + if !isempty(systems) |
| 294 | + cons = get_constraintsystems.(systems) |
| 295 | + @set! constraintsystem.systems = cons |
| 296 | + end |
| 297 | + |
307 | 298 | ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), |
308 | | - deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsys, tgrad, jac, |
| 299 | + deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac, |
309 | 300 | ctrl_jac, Wfact, Wfact_t, name, description, systems, |
310 | 301 | defaults, guesses, nothing, initializesystem, |
311 | 302 | initialization_eqs, schedule, connector_type, preface, cont_callbacks, |
@@ -377,9 +368,22 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...) |
377 | 368 | end |
378 | 369 | algevars = setdiff(allunknowns, diffvars) |
379 | 370 |
|
| 371 | + if !isempty(constraints) |
| 372 | + consvars = OrderedSet() |
| 373 | + constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv) |
| 374 | + for st in get_unknowns(constraintsystem) |
| 375 | + iscall(st) ? |
| 376 | + !in(operation(st)(iv), allunknowns) && push!(consvars, st) : |
| 377 | + !in(st, allunknowns) && push!(consvars, st) |
| 378 | + end |
| 379 | + for p in parameters(constraintsystem) |
| 380 | + !in(p, new_ps) && push!(new_ps, p) |
| 381 | + end |
| 382 | + end |
| 383 | + |
380 | 384 | # the orders here are very important! |
381 | 385 | return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv, |
382 | | - collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); constraints, kwargs...) |
| 386 | + collect(Iterators.flatten((diffvars, algevars, consvars))), collect(new_ps); constraintsystem, kwargs...) |
383 | 387 | end |
384 | 388 |
|
385 | 389 | # NOTE: equality does not check cached Jacobian |
@@ -791,55 +795,38 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true, |
791 | 795 | end |
792 | 796 |
|
793 | 797 | # Validate that all the variables in the BVP constraints are well-formed states or parameters. |
794 | | -# - Any callable with multiple arguments will error. |
795 | 798 | # - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.) |
796 | 799 | # - Callable/delay parameters should be parameters of the system (and have one arg, etc.) |
797 | | -function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv) |
| 800 | +function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons) |
| 801 | + isempty(constraints) && return nothing |
| 802 | + |
| 803 | + constraintsts = OrderedSet() |
| 804 | + constraintps = OrderedSet() |
| 805 | + |
| 806 | + # Hack? to extract parameters from callable variables in constraints. |
| 807 | + for cons in constraints |
| 808 | + collect_vars!(constraintsts, constraintps, cons, iv) |
| 809 | + end |
| 810 | + |
| 811 | + # Validate the states. |
798 | 812 | for var in constraintsts |
799 | 813 | if !iscall(var) |
800 | | - occursin(iv, var) && var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")) |
| 814 | + occursin(iv, var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))) |
801 | 815 | elseif length(arguments(var)) > 1 |
802 | 816 | throw(ArgumentError("Too many arguments for variable $var.")) |
803 | 817 | elseif length(arguments(var)) == 1 |
804 | | - arg = first(arguments(var)) |
| 818 | + arg = only(arguments(var)) |
805 | 819 | operation(var)(iv) ∈ sts || |
806 | 820 | throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem.")) |
807 | 821 |
|
808 | 822 | isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat || |
809 | 823 | throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) |
| 824 | + |
| 825 | + isparameter(arg) && push!(constraintps, arg) |
810 | 826 | else |
811 | 827 | var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval." |
812 | 828 | end |
813 | 829 | end |
814 | 830 |
|
815 | | - for var in constraintps |
816 | | - !iscall(var) && continue |
817 | | - |
818 | | - if length(arguments(var)) > 1 |
819 | | - throw(ArgumentError("Too many arguments for parameter $var in equation $eq.")) |
820 | | - elseif length(arguments(var)) == 1 |
821 | | - arg = first(arguments(var)) |
822 | | - operation(var) ∈ ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem.")) |
823 | | - |
824 | | - isequal(arg, iv) || |
825 | | - arg isa Integer || |
826 | | - arg isa AbstractFloat || |
827 | | - throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) |
828 | | - end |
829 | | - end |
830 | | -end |
831 | | - |
832 | | -function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv, subsys::Vector{ODESystem}; name = :cons) |
833 | | - isempty(constraints) && return nothing |
834 | | - |
835 | | - constraintsts = OrderedSet() |
836 | | - constraintps = OrderedSet() |
837 | | - |
838 | | - for cons in constraints |
839 | | - syms = collect_vars!(constraintsts, constraintps, cons, iv) |
840 | | - end |
841 | | - validate_constraint_syms(constraintsts, constraintps, Set(sts), Set(ps), iv) |
842 | | - |
843 | | - constraint_subsys = get_constraintsystem.(subsys) |
844 | | - ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); systems = constraint_subsys, name) |
| 831 | + ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname) |
845 | 832 | end |
0 commit comments