Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ using DocStringExtensions
export tearing, dae_index_lowering, check_consistency
export dummy_derivative
export sorted_incidence_matrix,
pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!,
pantelides!, pantelides_reassemble, find_solvables!,
linear_subsys_adjmat!
export tearing_substitution
export torn_system_jacobian_sparsity
Expand All @@ -69,9 +69,9 @@ export computed_highest_diff_variables
export shift2term, lower_shift_varname, simplify_shifts, distribute_shift

include("utils.jl")
include("tearing.jl")
include("pantelides.jl")
include("bipartite_tearing/modia_tearing.jl")
include("tearing.jl")
include("symbolics_tearing.jl")
include("partial_state_selection.jl")
include("codegen.jl")
Expand Down
25 changes: 16 additions & 9 deletions src/structural_transformation/bipartite_tearing/modia_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,22 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars
return nothing
end

function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned;
varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3}
function build_var_eq_matching(structure::SystemStructure;
varfilter::F2, eqfilter::F3) where {F2, F3}
@unpack graph, solvable_graph = structure
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, MatchedVarT)
matching_len = max(length(var_eq_matching),
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
return complete(var_eq_matching, matching_len), matching_len
end

function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
::Type{U} = Unassigned;
varfilter::F2 = v -> true,
eqfilter::F3 = eq -> true) where {F, U, F2, F3}
@kwdef struct ModiaTearing{F, F2, F3}
isder::F = nothing
varfilter::F2 = Returns(true)
eqfilter::F3 = Returns(true)
end

function (alg::ModiaTearing)(structure::SystemStructure)
# It would be possible here to simply iterate over all variables and attempt to
# use tearEquations! to produce a matching that greedily selects the minimal
# number of torn variables. However, we can do this process faster if we first
Expand All @@ -86,8 +89,11 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
# to have optimal solutions that cannot be found by this process. We will not
# find them here [TODO: It would be good to have an explicit example of this.]

isder = alg.isder
varfilter = alg.varfilter
eqfilter = alg.eqfilter
@unpack graph, solvable_graph = structure
var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter)
var_eq_matching, matching_len = build_var_eq_matching(structure; varfilter, eqfilter)
full_var_eq_matching = copy(var_eq_matching)
var_sccs = find_var_sccs(graph, var_eq_matching)
vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len))
Expand Down Expand Up @@ -126,5 +132,6 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, free_eqs,
BitSet(free_vars), isder)
end
return var_eq_matching, full_var_eq_matching, var_sccs

return TearingResult(var_eq_matching, full_var_eq_matching, var_sccs), (;)
end
30 changes: 16 additions & 14 deletions src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
struct SelectedState end

function dummy_derivative_graph!(state::TransformationState, jac = nothing;
state_priority = nothing, log = Val(false), kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
var_eq_matching = complete(pantelides!(state; kwargs...))
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log)
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log; kwargs...)
end

struct DummyDerivativeSummary
Expand All @@ -15,7 +13,8 @@ end

function dummy_derivative_graph!(
structure::SystemStructure, var_eq_matching, jac = nothing,
state_priority = nothing, ::Val{log} = Val(false)) where {log}
state_priority = nothing, ::Val{log} = Val(false);
tearing_alg::TearingAlgorithm = DummyDerivativeTearing(), kwargs...) where {log}
@unpack eq_to_diff, var_to_diff, graph = structure
diff_to_eq = invview(eq_to_diff)
diff_to_var = invview(var_to_diff)
Expand Down Expand Up @@ -173,8 +172,9 @@ function dummy_derivative_graph!(
@warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
end

ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives))
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
tearing_result, extra = tearing_alg(structure, BitSet(dummy_derivatives))
extra = (; extra..., ddsummary = DummyDerivativeSummary(var_dummy_scc, var_state_priority))
return tearing_result, extra
end

function is_present(structure, v)::Bool
Expand All @@ -201,7 +201,9 @@ function isdiffed((structure, dummy_derivatives), v)::Bool
diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v)
end

function tearing_with_dummy_derivatives(structure, dummy_derivatives)
struct DummyDerivativeTearing <: TearingAlgorithm end

function (::DummyDerivativeTearing)(structure::SystemStructure, dummy_derivatives::Union{BitSet, Tuple{}} = ())
@unpack var_to_diff = structure
# We can eliminate variables that are not selected (differential
# variables). Selected unknowns are differentiated variables that are not
Expand All @@ -213,18 +215,18 @@ function tearing_with_dummy_derivatives(structure, dummy_derivatives)
can_eliminate[v] = true
end
end
var_eq_matching, full_var_eq_matching,
var_sccs = tear_graph_modia(structure,
Base.Fix1(isdiffed, (structure, dummy_derivatives)),
Union{Unassigned, SelectedState};
varfilter = Base.Fix1(getindex, can_eliminate))
modia_tearing = ModiaTearing(;
isder = Base.Fix1(isdiffed, (structure, dummy_derivatives)),
varfilter = Base.Fix1(getindex, can_eliminate)
)
tearing_result, _ = modia_tearing(structure)

for v in 𝑑vertices(structure.graph)
is_present(structure, v) || continue
dv = var_to_diff[v]
(dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)) && continue
var_eq_matching[v] = SelectedState()
tearing_result.var_eq_matching[v] = SelectedState()
end

return var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate
return tearing_result, (; can_eliminate)
end
42 changes: 23 additions & 19 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1047,9 +1047,15 @@ differential variables.
- `var_sccs`: The topologically sorted strongly connected components of the system
according to `full_var_eq_matching`.
"""
function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm,
array_hack = true, fully_determined = true)
@kwdef struct DefaultReassembleAlgorithm <: ReassembleAlgorithm
simplify::Bool = false
array_hack::Bool = true
end

function (alg::DefaultReassembleAlgorithm)(state::TearingState, tearing_result::TearingResult, mm::Union{SparseMatrixCLIL, Nothing}; fully_determined::Bool = true, kw...)
@unpack simplify, array_hack = alg
@unpack var_eq_matching, full_var_eq_matching, var_sccs = tearing_result

extra_eqs_vars = get_extra_eqs_vars(
state, var_eq_matching, full_var_eq_matching, fully_determined)
neweqs = collect(equations(state))
Expand Down Expand Up @@ -1314,25 +1320,25 @@ end
ndims = ndims(arr)
end

function tearing(state::TearingState; kwargs...)
function tearing(state::TearingState; tearing_alg::TearingAlgorithm = DummyDerivativeTearing(),
kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
tearing_with_dummy_derivatives(state.structure, ())
tearing_alg(state.structure)
end

"""
tearing(sys; simplify=false)
tearing(sys)

Tear the nonlinear equations in system. When `simplify=true`, we simplify the
new residual equations after tearing. End users are encouraged to call [`mtkcompile`](@ref)
instead, which calls this function internally.
"""
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
simplify = false, array_hack = true, fully_determined = true, kwargs...)
var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing(state)
invalidate_cache!(tearing_reassemble(
state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
simplify, array_hack, fully_determined))
reassemble_alg::ReassembleAlgorithm = DefaultReassembleAlgorithm(),
fully_determined = true, kwargs...)
tearing_result, extras = tearing(state; kwargs...)
invalidate_cache!(reassemble_alg(state, tearing_result, mm; fully_determined))
end

"""
Expand All @@ -1341,8 +1347,9 @@ end
Perform index reduction and use the dummy derivative technique to ensure that
the system is balanced.
"""
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
mm = nothing, array_hack = true, fully_determined = true, kwargs...)
function dummy_derivative(sys, state = TearingState(sys);
reassemble_alg::ReassembleAlgorithm = DefaultReassembleAlgorithm(),
mm = nothing, fully_determined = true, kwargs...)
jac = let state = state
(eqs, vars) -> begin
symeqs = EquationsView(state)[eqs]
Expand All @@ -1364,10 +1371,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
p
end
end
var_eq_matching, full_var_eq_matching, var_sccs,
can_eliminate, summary = dummy_derivative_graph!(
state, jac; state_priority,
kwargs...)
tearing_reassemble(state, var_eq_matching, full_var_eq_matching, var_sccs;
simplify, mm, array_hack, fully_determined)
tearing_result, extras = dummy_derivative_graph!(
state, jac; state_priority, kwargs...)
reassemble_alg(state, tearing_result, mm; fully_determined)
end
62 changes: 62 additions & 0 deletions src/structural_transformation/tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,65 @@ function free_equations(graph, vars_scc, var_eq_matching, varfilter::F) where {F
end
findall(!, seen_eqs)
end

struct SelectedState end
const MatchingT{T} = Matching{T, Vector{Union{T, Int}}}
const MatchedVarT = Union{Unassigned, SelectedState}
const VarEqMatchingT = MatchingT{MatchedVarT}

"""
$TYPEDEF

A struct containing the results of tearing.

# Fields

$TYPEDFIELDS
"""
struct TearingResult
"""
The variable-equation matching. Differential variables are matched to `SelectedState`.
The derivative of a differential variable is matched to the corresponding differential
equation. Solved variables are matched to the equation they are solved from. Algebraic
variables are matched to `unassigned`.
"""
var_eq_matching::VarEqMatchingT
"""
The variable-equation matching prior to tearing. This is the maximal matching used to
compute `var_sccs` (see below). For generating the torn system, `var_eq_matching` is
the source of truth. This should only be used to identify algebraic equations in each
SCC.
"""
full_var_eq_matching::VarEqMatchingT
"""
The partitioning of variables into strongly connected components (SCCs). The SCCs are
sorted in dependency order, so each SCC depends on variables in previous SCCs.
"""
var_sccs::Vector{Vector{Int}}
end

"""
$TYPEDEF

Supertype for all tearing algorithms. A tearing algorithm takes as input the
`SystemStructure` along with any other necessary arguments.

The output of a tearing algorithm must be a `TearingResult` and a `NamedTuple` of
any additional data computed in the process that may be useful for further processing.
"""
abstract type TearingAlgorithm end

"""
$TYPEDEF

Supertype for all reassembling algorithms. A reassembling algorithm takes as input the
`TearingState`, `TearingResult` and integer incidence matrix `mm::SparseMatrixCLIL`. The
matrix `mm` may be `nothing`. The algorithm must also accept arbitrary keyword arguments.
The following keyword arguments will always be provided:
- `fully_determined::Bool`: flag indicating whether the system is fully determined.

The output of a reassembling algorithm must be the torn system.

A reassemble algorithm must also implement `with_fully_determined`
"""
abstract type ReassembleAlgorithm end
16 changes: 9 additions & 7 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ function mtkcompile(
sys::AbstractSystem; additional_passes = [], simplify = false, split = true,
allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true,
inputs = Any[], outputs = Any[],
disturbance_inputs = Any[],
disturbance_inputs = Any[], array_hack = true,
kwargs...)
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __mtkcompile(sys; simplify,
reassemble_alg = get(kwargs, :reassemble_alg,
StructuralTransformations.DefaultReassembleAlgorithm(; simplify, array_hack))
newsys′ = __mtkcompile(sys;
allow_symbolic, allow_parameter, conservative, fully_determined,
inputs, outputs, disturbance_inputs, additional_passes,
inputs, outputs, disturbance_inputs, additional_passes, reassemble_alg,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand All @@ -59,7 +61,7 @@ function mtkcompile(
end
end

function __mtkcompile(sys::AbstractSystem; simplify = false,
function __mtkcompile(sys::AbstractSystem;
inputs = Any[], outputs = Any[],
disturbance_inputs = Any[],
sort_eqs = true,
Expand All @@ -72,7 +74,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
return sys
end
if isempty(equations(sys)) && !is_time_dependent(sys) && !_iszero(cost(sys))
return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
return simplify_optimization_system(sys; kwargs..., sort_eqs)
end

sys, statemachines = extract_top_level_statemachines(sys)
Expand All @@ -94,7 +96,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
end
if isempty(brown_vars)
return mtkcompile!(
state; simplify, inputs, outputs, disturbance_inputs, kwargs...)
state; inputs, outputs, disturbance_inputs, kwargs...)
else
Is = Int[]
Js = Int[]
Expand Down Expand Up @@ -129,7 +131,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
if !iszero(new_idxs[i]) &&
invview(var_to_diff)[i] === nothing]
ode_sys = mtkcompile(
sys; simplify, inputs, outputs, disturbance_inputs, kwargs...)
sys; inputs, outputs, disturbance_inputs, kwargs...)
eqs = equations(ode_sys)
sorted_g_rows = zeros(Num, length(eqs), size(g, 2))
for (i, eq) in enumerate(eqs)
Expand Down
Loading
Loading