diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 68ea0b48e3..d1cd01ce7b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...) compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)` for eq in eqs eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue) - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) if isdiffeq(eq) diffvar, _ = var_from_nested_derivative(eq.lhs) if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0) @@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], iv) - collect_vars!(allunknowns, ps, eq[2], iv) + collect_vars!(allunknowns, ps, eq, iv) else - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) end end for ssys in get(kwargs, :systems, ODESystem[]) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 7103cfca80..99f76a8ce9 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...) ps = OrderedSet() iv = value(iv) for eq in eqs - collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift) - collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift) + collect_vars!(allunknowns, ps, eq, iv; op = Shift) if iscall(eq.lhs) && operation(eq.lhs) isa Shift isequal(iv, operation(eq.lhs).t) || throw(ArgumentError("A DiscreteSystem can only have one independent variable.")) @@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], iv) - collect_vars!(allunknowns, ps, eq[2], iv) + collect_vars!(allunknowns, ps, eq, iv) else - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) end end new_ps = OrderedSet() diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 46bf032d6f..c649b9b287 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...) allunknowns = OrderedSet() ps = OrderedSet() for eq in eqs - collect_vars!(allunknowns, ps, eq.lhs, nothing) - collect_vars!(allunknowns, ps, eq.rhs, nothing) + collect_vars!(allunknowns, ps, eq, nothing) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], nothing) - collect_vars!(allunknowns, ps, eq[2], nothing) + collect_vars!(allunknowns, ps, eq, nothing) else - collect_vars!(allunknowns, ps, eq.lhs, nothing) - collect_vars!(allunknowns, ps, eq.rhs, nothing) + collect_vars!(allunknowns, ps, eq, nothing) end end new_ps = OrderedSet() diff --git a/src/utils.jl b/src/utils.jl index e8ed131d78..c13d8a480a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -492,20 +492,19 @@ recursively searches through all subsystems of `sys`, increasing the depth if it function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential) if has_eqs(sys) for eq in get_eqs(sys) - eq isa Equation || continue - eq.lhs isa Union{Symbolic, Number} || continue - collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) - collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + eqtype_supports_collect_vars(eq) || continue + if eq isa Equation + eq.lhs isa Union{Symbolic, Number} || continue + end + collect_vars!(unknowns, parameters, eq, iv; depth, op) end end if has_parameter_dependencies(sys) for eq in get_parameter_dependencies(sys) if eq isa Pair - collect_vars!(unknowns, parameters, eq[1], iv; depth, op) - collect_vars!(unknowns, parameters, eq[2], iv; depth, op) + collect_vars!(unknowns, parameters, eq, iv; depth, op) else - collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) - collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + collect_vars!(unknowns, parameters, eq, iv; depth, op) end end end @@ -529,6 +528,29 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different return nothing end +""" + $(TYPEDSIGNATURES) + +Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`. +Can be dispatched by higher-level libraries to indicate support. +""" +eqtype_supports_collect_vars(eq) = false +eqtype_supports_collect_vars(eq::Equation) = true +eqtype_supports_collect_vars(eq::Pair) = true + +function collect_vars!(unknowns, parameters, eq::Equation, iv; + depth = 0, op = Differential) + collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) + collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + return nothing +end + +function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential) + collect_vars!(unknowns, parameters, p[1], iv; depth, op) + collect_vars!(unknowns, parameters, p[2], iv; depth, op) + return nothing +end + function collect_var!(unknowns, parameters, var, iv; depth = 0) isequal(var, iv) && return nothing check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing