diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 681025cb81..1cd4c7615f 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -59,7 +59,7 @@ using DocStringExtensions export tearing, dae_index_lowering, check_consistency export dummy_derivative export sorted_incidence_matrix, - pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!, + pantelides!, pantelides_reassemble, find_solvables!, linear_subsys_adjmat! export tearing_substitution export torn_system_jacobian_sparsity @@ -69,9 +69,9 @@ export computed_highest_diff_variables export shift2term, lower_shift_varname, simplify_shifts, distribute_shift include("utils.jl") +include("tearing.jl") include("pantelides.jl") include("bipartite_tearing/modia_tearing.jl") -include("tearing.jl") include("symbolics_tearing.jl") include("partial_state_selection.jl") include("codegen.jl") diff --git a/src/structural_transformation/bipartite_tearing/modia_tearing.jl b/src/structural_transformation/bipartite_tearing/modia_tearing.jl index 5da873afdf..b931c61137 100644 --- a/src/structural_transformation/bipartite_tearing/modia_tearing.jl +++ b/src/structural_transformation/bipartite_tearing/modia_tearing.jl @@ -62,19 +62,22 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars return nothing end -function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned; - varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3} +function build_var_eq_matching(structure::SystemStructure; + varfilter::F2, eqfilter::F3) where {F2, F3} @unpack graph, solvable_graph = structure - var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U) + var_eq_matching = maximal_matching(graph, eqfilter, varfilter, MatchedVarT) matching_len = max(length(var_eq_matching), maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0)) return complete(var_eq_matching, matching_len), matching_len end -function tear_graph_modia(structure::SystemStructure, isder::F = nothing, - ::Type{U} = Unassigned; - varfilter::F2 = v -> true, - eqfilter::F3 = eq -> true) where {F, U, F2, F3} +@kwdef struct ModiaTearing{F, F2, F3} + isder::F = nothing + varfilter::F2 = Returns(true) + eqfilter::F3 = Returns(true) +end + +function (alg::ModiaTearing)(structure::SystemStructure) # It would be possible here to simply iterate over all variables and attempt to # use tearEquations! to produce a matching that greedily selects the minimal # number of torn variables. However, we can do this process faster if we first @@ -86,8 +89,11 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, # to have optimal solutions that cannot be found by this process. We will not # find them here [TODO: It would be good to have an explicit example of this.] + isder = alg.isder + varfilter = alg.varfilter + eqfilter = alg.eqfilter @unpack graph, solvable_graph = structure - var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter) + var_eq_matching, matching_len = build_var_eq_matching(structure; varfilter, eqfilter) full_var_eq_matching = copy(var_eq_matching) var_sccs = find_var_sccs(graph, var_eq_matching) vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len)) @@ -126,5 +132,6 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, free_eqs, BitSet(free_vars), isder) end - return var_eq_matching, full_var_eq_matching, var_sccs + + return TearingResult(var_eq_matching, full_var_eq_matching, var_sccs), (;) end diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index ab8e7f0f3d..98e40e9988 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -1,11 +1,9 @@ -struct SelectedState end - function dummy_derivative_graph!(state::TransformationState, jac = nothing; state_priority = nothing, log = Val(false), kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) var_eq_matching = complete(pantelides!(state; kwargs...)) - dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log) + dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log; kwargs...) end struct DummyDerivativeSummary @@ -15,7 +13,8 @@ end function dummy_derivative_graph!( structure::SystemStructure, var_eq_matching, jac = nothing, - state_priority = nothing, ::Val{log} = Val(false)) where {log} + state_priority = nothing, ::Val{log} = Val(false); + tearing_alg::TearingAlgorithm = DummyDerivativeTearing(), kwargs...) where {log} @unpack eq_to_diff, var_to_diff, graph = structure diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) @@ -173,8 +172,9 @@ function dummy_derivative_graph!( @warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)." end - ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives)) - (ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority)) + tearing_result, extra = tearing_alg(structure, BitSet(dummy_derivatives)) + extra = (; extra..., ddsummary = DummyDerivativeSummary(var_dummy_scc, var_state_priority)) + return tearing_result, extra end function is_present(structure, v)::Bool @@ -201,7 +201,9 @@ function isdiffed((structure, dummy_derivatives), v)::Bool diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v) end -function tearing_with_dummy_derivatives(structure, dummy_derivatives) +struct DummyDerivativeTearing <: TearingAlgorithm end + +function (::DummyDerivativeTearing)(structure::SystemStructure, dummy_derivatives::Union{BitSet, Tuple{}} = ()) @unpack var_to_diff = structure # We can eliminate variables that are not selected (differential # variables). Selected unknowns are differentiated variables that are not @@ -213,18 +215,18 @@ function tearing_with_dummy_derivatives(structure, dummy_derivatives) can_eliminate[v] = true end end - var_eq_matching, full_var_eq_matching, - var_sccs = tear_graph_modia(structure, - Base.Fix1(isdiffed, (structure, dummy_derivatives)), - Union{Unassigned, SelectedState}; - varfilter = Base.Fix1(getindex, can_eliminate)) + modia_tearing = ModiaTearing(; + isder = Base.Fix1(isdiffed, (structure, dummy_derivatives)), + varfilter = Base.Fix1(getindex, can_eliminate) + ) + tearing_result, _ = modia_tearing(structure) for v in 𝑑vertices(structure.graph) is_present(structure, v) || continue dv = var_to_diff[v] (dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)) && continue - var_eq_matching[v] = SelectedState() + tearing_result.var_eq_matching[v] = SelectedState() end - return var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate + return tearing_result, (; can_eliminate) end diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index e079ba30b5..540944bb6a 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -1047,9 +1047,15 @@ differential variables. - `var_sccs`: The topologically sorted strongly connected components of the system according to `full_var_eq_matching`. """ -function tearing_reassemble(state::TearingState, var_eq_matching::Matching, - full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, - array_hack = true, fully_determined = true) +@kwdef struct DefaultReassembleAlgorithm <: ReassembleAlgorithm + simplify::Bool = false + array_hack::Bool = true +end + +function (alg::DefaultReassembleAlgorithm)(state::TearingState, tearing_result::TearingResult, mm::Union{SparseMatrixCLIL, Nothing}; fully_determined::Bool = true, kw...) + @unpack simplify, array_hack = alg + @unpack var_eq_matching, full_var_eq_matching, var_sccs = tearing_result + extra_eqs_vars = get_extra_eqs_vars( state, var_eq_matching, full_var_eq_matching, fully_determined) neweqs = collect(equations(state)) @@ -1314,25 +1320,25 @@ end ndims = ndims(arr) end -function tearing(state::TearingState; kwargs...) +function tearing(state::TearingState; tearing_alg::TearingAlgorithm = DummyDerivativeTearing(), + kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) - tearing_with_dummy_derivatives(state.structure, ()) + tearing_alg(state.structure) end """ - tearing(sys; simplify=false) + tearing(sys) Tear the nonlinear equations in system. When `simplify=true`, we simplify the new residual equations after tearing. End users are encouraged to call [`mtkcompile`](@ref) instead, which calls this function internally. """ function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing, - simplify = false, array_hack = true, fully_determined = true, kwargs...) - var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing(state) - invalidate_cache!(tearing_reassemble( - state, var_eq_matching, full_var_eq_matching, var_sccs; mm, - simplify, array_hack, fully_determined)) + reassemble_alg::ReassembleAlgorithm = DefaultReassembleAlgorithm(), + fully_determined = true, kwargs...) + tearing_result, extras = tearing(state; kwargs...) + invalidate_cache!(reassemble_alg(state, tearing_result, mm; fully_determined)) end """ @@ -1341,8 +1347,9 @@ end Perform index reduction and use the dummy derivative technique to ensure that the system is balanced. """ -function dummy_derivative(sys, state = TearingState(sys); simplify = false, - mm = nothing, array_hack = true, fully_determined = true, kwargs...) +function dummy_derivative(sys, state = TearingState(sys); + reassemble_alg::ReassembleAlgorithm = DefaultReassembleAlgorithm(), + mm = nothing, fully_determined = true, kwargs...) jac = let state = state (eqs, vars) -> begin symeqs = EquationsView(state)[eqs] @@ -1364,10 +1371,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, p end end - var_eq_matching, full_var_eq_matching, var_sccs, - can_eliminate, summary = dummy_derivative_graph!( - state, jac; state_priority, - kwargs...) - tearing_reassemble(state, var_eq_matching, full_var_eq_matching, var_sccs; - simplify, mm, array_hack, fully_determined) + tearing_result, extras = dummy_derivative_graph!( + state, jac; state_priority, kwargs...) + reassemble_alg(state, tearing_result, mm; fully_determined) end diff --git a/src/structural_transformation/tearing.jl b/src/structural_transformation/tearing.jl index 67933ffe0e..77a004a823 100644 --- a/src/structural_transformation/tearing.jl +++ b/src/structural_transformation/tearing.jl @@ -83,3 +83,65 @@ function free_equations(graph, vars_scc, var_eq_matching, varfilter::F) where {F end findall(!, seen_eqs) end + +struct SelectedState end +const MatchingT{T} = Matching{T, Vector{Union{T, Int}}} +const MatchedVarT = Union{Unassigned, SelectedState} +const VarEqMatchingT = MatchingT{MatchedVarT} + +""" + $TYPEDEF + +A struct containing the results of tearing. + +# Fields + +$TYPEDFIELDS +""" +struct TearingResult + """ + The variable-equation matching. Differential variables are matched to `SelectedState`. + The derivative of a differential variable is matched to the corresponding differential + equation. Solved variables are matched to the equation they are solved from. Algebraic + variables are matched to `unassigned`. + """ + var_eq_matching::VarEqMatchingT + """ + The variable-equation matching prior to tearing. This is the maximal matching used to + compute `var_sccs` (see below). For generating the torn system, `var_eq_matching` is + the source of truth. This should only be used to identify algebraic equations in each + SCC. + """ + full_var_eq_matching::VarEqMatchingT + """ + The partitioning of variables into strongly connected components (SCCs). The SCCs are + sorted in dependency order, so each SCC depends on variables in previous SCCs. + """ + var_sccs::Vector{Vector{Int}} +end + +""" + $TYPEDEF + +Supertype for all tearing algorithms. A tearing algorithm takes as input the +`SystemStructure` along with any other necessary arguments. + +The output of a tearing algorithm must be a `TearingResult` and a `NamedTuple` of +any additional data computed in the process that may be useful for further processing. +""" +abstract type TearingAlgorithm end + +""" + $TYPEDEF + +Supertype for all reassembling algorithms. A reassembling algorithm takes as input the +`TearingState`, `TearingResult` and integer incidence matrix `mm::SparseMatrixCLIL`. The +matrix `mm` may be `nothing`. The algorithm must also accept arbitrary keyword arguments. +The following keyword arguments will always be provided: +- `fully_determined::Bool`: flag indicating whether the system is fully determined. + +The output of a reassembling algorithm must be the torn system. + +A reassemble algorithm must also implement `with_fully_determined` +""" +abstract type ReassembleAlgorithm end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 49efb3e1d4..9173be44fc 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -31,12 +31,14 @@ function mtkcompile( sys::AbstractSystem; additional_passes = [], simplify = false, split = true, allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true, inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + disturbance_inputs = Any[], array_hack = true, kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) - newsys′ = __mtkcompile(sys; simplify, + reassemble_alg = get(kwargs, :reassemble_alg, + StructuralTransformations.DefaultReassembleAlgorithm(; simplify, array_hack)) + newsys′ = __mtkcompile(sys; allow_symbolic, allow_parameter, conservative, fully_determined, - inputs, outputs, disturbance_inputs, additional_passes, + inputs, outputs, disturbance_inputs, additional_passes, reassemble_alg, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -59,7 +61,7 @@ function mtkcompile( end end -function __mtkcompile(sys::AbstractSystem; simplify = false, +function __mtkcompile(sys::AbstractSystem; inputs = Any[], outputs = Any[], disturbance_inputs = Any[], sort_eqs = true, @@ -72,7 +74,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return sys end if isempty(equations(sys)) && !is_time_dependent(sys) && !_iszero(cost(sys)) - return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) + return simplify_optimization_system(sys; kwargs..., sort_eqs) end sys, statemachines = extract_top_level_statemachines(sys) @@ -94,7 +96,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, end if isempty(brown_vars) return mtkcompile!( - state; simplify, inputs, outputs, disturbance_inputs, kwargs...) + state; inputs, outputs, disturbance_inputs, kwargs...) else Is = Int[] Js = Int[] @@ -129,7 +131,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, if !iszero(new_idxs[i]) && invview(var_to_diff)[i] === nothing] ode_sys = mtkcompile( - sys; simplify, inputs, outputs, disturbance_inputs, kwargs...) + sys; inputs, outputs, disturbance_inputs, kwargs...) eqs = equations(ode_sys) sorted_g_rows = zeros(Num, length(eqs), size(g, 2)) for (i, eq) in enumerate(eqs) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index c867306c34..d05fc74dd0 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -920,13 +920,13 @@ function make_eqs_zero_equals!(ts::TearingState) copyto!(get_eqs(ts.sys), neweqs) end -function mtkcompile!(state::TearingState; simplify = false, +function mtkcompile!(state::TearingState; check_consistency = true, fully_determined = true, warn_initialize_determined = true, inputs = Any[], outputs = Any[], disturbance_inputs = Any[], kwargs...) if !is_time_dependent(state.sys) - return _mtkcompile!(state; simplify, check_consistency, + return _mtkcompile!(state; check_consistency, inputs, outputs, disturbance_inputs, fully_determined, kwargs...) end @@ -956,7 +956,7 @@ function mtkcompile!(state::TearingState; simplify = false, if length(tss) > 1 make_eqs_zero_equals!(tss[continuous_id]) # simplify as normal - sys = _mtkcompile!(tss[continuous_id]; simplify, + sys = _mtkcompile!(tss[continuous_id]; inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, check_consistency, fully_determined, kwargs...) @@ -986,13 +986,13 @@ function mtkcompile!(state::TearingState; simplify = false, state.sys = sys end - sys = _mtkcompile!(state; simplify, check_consistency, + sys = _mtkcompile!(state; check_consistency, inputs, outputs, disturbance_inputs, fully_determined, kwargs...) return sys end -function _mtkcompile!(state::TearingState; simplify = false, +function _mtkcompile!(state::TearingState; check_consistency = true, fully_determined = true, warn_initialize_determined = false, dummy_derivative = true, inputs = Any[], outputs = Any[], @@ -1014,17 +1014,17 @@ function _mtkcompile!(state::TearingState; simplify = false, end if fully_determined && dummy_derivative sys = ModelingToolkit.dummy_derivative( - sys, state; simplify, mm, check_consistency, kwargs...) + sys, state; mm, check_consistency, kwargs...) elseif fully_determined var_eq_matching = pantelides!(state; finalize = false, kwargs...) sys = pantelides_reassemble(state, var_eq_matching) state = TearingState(sys) sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) sys = ModelingToolkit.dummy_derivative( - sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) + sys, state; mm, check_consistency, fully_determined, kwargs...) else sys = ModelingToolkit.tearing( - sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) + sys, state; mm, check_consistency, fully_determined, kwargs...) end fullunknowns = [observables(sys); unknowns(sys)] @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns) diff --git a/test/structural_transformation/tearing.jl b/test/structural_transformation/tearing.jl index 4025f7a298..d6e76918cb 100644 --- a/test/structural_transformation/tearing.jl +++ b/test/structural_transformation/tearing.jl @@ -97,8 +97,8 @@ newsys = tearing(sys) # e5 [ 1 1 | 1 ] let state = TearingState(sys) - torn_matching, = tearing(state) - S = StructuralTransformations.reordered_matrix(sys, torn_matching) + result, = tearing(state) + S = StructuralTransformations.reordered_matrix(sys, result.var_eq_matching) @test S == [1 0 0 0 1 1 1 0 0 0 1 1 1 0 0