diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 1220d517cc..f0124d7f4b 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -11,7 +11,8 @@ using SymbolicUtils: maketerm, iscall using ModelingToolkit using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential, - unknowns, equations, vars, Symbolic, diff2term_with_unit, value, + unknowns, equations, vars, Symbolic, diff2term_with_unit, + shift2term_with_unit, value, operation, arguments, Sym, Term, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, empty_substitutions, get_substitutions, @@ -22,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di get_postprocess_fbody, vars!, IncrementalCycleTracker, add_edge_checked!, topological_sort, invalidate_cache!, Substitutions, get_or_construct_tearing_state, - filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL, + filter_kwargs, lower_varname_with_unit, + lower_shift_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, Schedule, schedule @@ -63,6 +65,7 @@ export torn_system_jacobian_sparsity export full_equations export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask export computed_highest_diff_variables +export shift2term, lower_shift_varname include("utils.jl") include("pantelides.jl") diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index fe3ce45430..9818bba361 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -237,49 +237,23 @@ function check_diff_graph(var_to_diff, fullvars) end =# -function tearing_reassemble(state::TearingState, var_eq_matching, - full_var_eq_matching = nothing; simplify = 
false, mm = nothing, cse_hack = true, array_hack = true) - @unpack fullvars, sys, structure = state - @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure - extra_vars = Int[] - if full_var_eq_matching !== nothing - for v in 𝑑vertices(state.structure.graph) - eq = full_var_eq_matching[v] - eq isa Int && continue - push!(extra_vars, v) - end - end +""" +Replace derivatives of non-selected unknown variables by dummy derivatives. - neweqs = collect(equations(state)) - # Terminology and Definition: - # - # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can - # characterize variables in `u(t)` into two classes: differential variables - # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential - # variables are marked as `SelectedState` and they are differentiated in the - # DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually - # appear in the system. Algebraic variables are variables that are not - # differential variables. - # - # Dummy derivatives may determine that some differential variables are - # algebraic variables in disguise. The derivative of such variables are - # called dummy derivatives. - - # Step 1: - # Replace derivatives of non-selected unknown variables by dummy derivatives +State selection may determine that some differential variables are +algebraic variables in disguise. The derivative of such variables are +called dummy derivatives. - if ModelingToolkit.has_iv(state.sys) - iv = get_iv(state.sys) - if is_only_discrete(state.structure) - D = Shift(iv, 1) - else - D = Differential(iv) - end - else - iv = D = nothing - end +`SelectedState` information is no longer needed after this function is called. +State selection is done. All non-differentiated variables are algebraic +variables, and all variables that appear differentiated are differential variables. 
+""" +function substitute_derivatives_algevars!( + ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing) + @unpack fullvars, sys, structure = ts + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) - dummy_sub = Dict() + for var in 1:length(fullvars) dv = var_to_diff[var] dv === nothing && continue @@ -310,294 +284,451 @@ function tearing_reassemble(state::TearingState, var_eq_matching, diff_to_var[dv] = nothing end end +end - # `SelectedState` information is no longer needed past here. State selection - # is done. All non-differentiated variables are algebraic variables, and all - # variables that appear differentiated are differential variables. - - ### extract partition information - is_solvable = let solvable_graph = solvable_graph - (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph - end - - # if var is like D(x) - isdervar = let diff_to_var = diff_to_var - var -> diff_to_var[var] !== nothing - end - var_order = let diff_to_var = diff_to_var - dv -> begin - order = 0 - while (dv′ = diff_to_var[dv]) !== nothing - order += 1 - dv = dv′ - end - order, dv - end - end - - #retear = BitSet() - # There are three cases where we want to generate new variables to convert - # the system into first order (semi-implicit) ODEs. - # - # 1. To first order: - # Whenever higher order differentiated variable like `D(D(D(x)))` appears, - # we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations - # ``` - # D(x_tt) = x_ttt - # D(x_t) = x_tt - # D(x) = x_t - # ``` - # and replace `D(x)` to `x_t`, `D(D(x))` to `x_tt`, and `D(D(D(x)))` to - # `x_ttt`. - # - # 2. To implicit to semi-implicit ODEs: - # 2.1: Unsolvable derivative: - # If one derivative variable `D(x)` is unsolvable in all the equations it - # appears in, then we introduce a new variable `x_t`, a new equation - # ``` - # D(x) ~ x_t - # ``` - # and replace all other `D(x)` to `x_t`. 
- # - # 2.2: Solvable derivative: - # If one derivative variable `D(x)` is solvable in at least one of the - # equations it appears in, then we introduce a new variable `x_t`. One of - # the solvable equations must be in the form of `0 ~ L(D(x), u...)` and - # there exists a function `l` such that `D(x) ~ l(u...)`. We should replace - # it to - # ``` - # 0 ~ x_t - l(u...) - # D(x) ~ x_t - # ``` - # and replace all other `D(x)` to `x_t`. - # - # Observe that we don't need to actually introduce a new variable `x_t`, as - # the above equations can be lowered to - # ``` - # x_t := l(u...) - # D(x) ~ x_t - # ``` - # where `:=` denotes assignment. - # - # As a final note, in all the above cases where we need to introduce new - # variables and equations, don't add them when they already exist. +#= +There are three cases where we want to generate new variables to convert +the system into first order (semi-implicit) ODEs. + +1. To first order: +Whenever higher order differentiated variable like `D(D(D(x)))` appears, +we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations +``` +D(x_tt) = x_ttt +D(x_t) = x_tt +D(x) = x_t +``` +and replace `D(x)` to `x_t`, `D(D(x))` to `x_tt`, and `D(D(D(x)))` to +`x_ttt`. + +2. To implicit to semi-implicit ODEs: +2.1: Unsolvable derivative: +If one derivative variable `D(x)` is unsolvable in all the equations it +appears in, then we introduce a new variable `x_t`, a new equation +``` +D(x) ~ x_t +``` +and replace all other `D(x)` to `x_t`. + +2.2: Solvable derivative: +If one derivative variable `D(x)` is solvable in at least one of the +equations it appears in, then we introduce a new variable `x_t`. One of +the solvable equations must be in the form of `0 ~ L(D(x), u...)` and +there exists a function `l` such that `D(x) ~ l(u...)`. We should replace +it to +``` +0 ~ x_t - l(u...) +D(x) ~ x_t +``` +and replace all other `D(x)` to `x_t`. 
+ +Observe that we don't need to actually introduce a new variable `x_t`, as +the above equations can be lowered to +``` +x_t := l(u...) +D(x) ~ x_t +``` +where `:=` denotes assignment. + +As a final note, in all the above cases where we need to introduce new +variables and equations, don't add them when they already exist. + +###### DISCRETE SYSTEMS ####### + +Documenting the differences to structural simplification for discrete systems: + +In discrete systems everything gets shifted forward a timestep by `shift_discrete_system` +in order to properly generate the difference equations. + +The system x(k) ~ x(k-1) + x(k-2) becomes Shift(t, 1)(x(t)) ~ x(t) + Shift(t, -1)(x(t)) + +The lowest-order term is Shift(t, k)(x(t)), instead of x(t). As such we actually want +dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms. + +Shift(t, -1)(x(t)) -> x\_{t-1}(t) + +Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`. +No equation or variable is added for it. + +For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc. +The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t))))) +does not actually appear. So `total_sub` in `generate_system_equations!` is directly +initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc. +=# +""" +Generate new derivative variables for the system. 
+ +Effects on the system structure: +- fullvars: add the new derivative variables x_t +- neweqs: add the identity equations for the new variables, D(x) ~ x_t +- graph: update graph with the new equations and variables, and their connections +- solvable_graph: +- var_eq_matching: match D(x) to the added identity equation D(x) ~ x_t +""" +function generate_derivative_variables!( + ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing) + @unpack fullvars, sys, structure = ts + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) + diff_to_var = invview(var_to_diff) + is_discrete = is_only_discrete(structure) linear_eqs = mm === nothing ? Dict{Int, Int}() : Dict(reverse(en) for en in enumerate(mm.nzrows)) + + # For variable x, make dummy derivative x_t if the + # derivative is in the system for v in 1:length(var_to_diff) dv = var_to_diff[v] dv isa Int || continue solved = var_eq_matching[dv] isa Int solved && continue - # check if there's `D(x) = x_t` already - local v_t, dummy_eq - for eq in 𝑑neighbors(solvable_graph, dv) - mi = get(linear_eqs, eq, 0) - iszero(mi) && continue - row = @view mm[mi, :] - nzs = nonzeros(row) - rvs = SparseArrays.nonzeroinds(row) - # note that `v_t` must not be differentiated - if length(nzs) == 2 && - (abs(nzs[1]) == 1 && nzs[1] == -nzs[2]) && - (v_t = rvs[1] == dv ? 
rvs[2] : rvs[1]; - diff_to_var[v_t] === nothing) - @assert dv in rvs - dummy_eq = eq - @goto FOUND_DUMMY_EQ - end + + # If there's `D(x) = x_t` already, update mappings and continue without + # adding new equations/variables + dd = find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm) + if !isnothing(dd) + dummy_eq, v_t = dd + var_to_diff[v_t] = var_to_diff[dv] + var_eq_matching[dv] = unassigned + eq_var_matching[dummy_eq] = dv + continue end + dx = fullvars[dv] - # add `x_t` - order, lv = var_order(dv) - x_t = lower_varname_withshift(fullvars[lv], iv, order) - push!(fullvars, simplify_shifts(x_t)) - v_t = length(fullvars) - v_t_idx = add_vertex!(var_to_diff) - add_vertex!(graph, DST) - # TODO: do we care about solvable_graph? We don't use them after - # `dummy_derivative_graph`. - add_vertex!(solvable_graph, DST) - # var_eq_matching is a bit odd. - # length(var_eq_matching) == length(invview(var_eq_matching)) + order, lv = var_order(dv, diff_to_var) + x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) : + lower_varname_with_unit(fullvars[lv], iv, order) + + # Add `x_t` to the graph + v_t = add_dd_variable!(structure, fullvars, x_t, dv) + # Add `D(x) - x_t ~ 0` to the graph + dummy_eq = add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv, v_t) + + # Update matching push!(var_eq_matching, unassigned) - @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) == - length(var_eq_matching) - # add `D(x) - x_t ~ 0` - push!(neweqs, 0 ~ x_t - dx) - add_vertex!(graph, SRC) - dummy_eq = length(neweqs) - add_edge!(graph, dummy_eq, dv) - add_edge!(graph, dummy_eq, v_t) - add_vertex!(solvable_graph, SRC) - add_edge!(solvable_graph, dummy_eq, dv) - @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq - @label FOUND_DUMMY_EQ - var_to_diff[v_t] = var_to_diff[dv] var_eq_matching[dv] = unassigned eq_var_matching[dummy_eq] = dv end +end + +""" +Check if there's `D(x) ~ x_t` already. 
+""" +function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm) + for eq in 𝑑neighbors(solvable_graph, dv) + mi = get(linear_eqs, eq, 0) + iszero(mi) && continue + row = @view mm[mi, :] + nzs = nonzeros(row) + rvs = SparseArrays.nonzeroinds(row) + # note that `v_t` must not be differentiated + if length(nzs) == 2 && + (abs(nzs[1]) == 1 && nzs[1] == -nzs[2]) && + (v_t = rvs[1] == dv ? rvs[2] : rvs[1]; + diff_to_var[v_t] === nothing) + @assert dv in rvs + return eq, v_t + end + end + return nothing +end + +""" +Add a dummy derivative variable x_t corresponding to symbolic variable D(x) +which has index dv in `fullvars`. Return the new index of x_t. +""" +function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv) + push!(fullvars, simplify_shifts(x_t)) + v_t = length(fullvars) + v_t_idx = add_vertex!(s.var_to_diff) + add_vertex!(s.graph, DST) + # TODO: do we care about solvable_graph? We don't use them after + # `dummy_derivative_graph`. + add_vertex!(s.solvable_graph, DST) + s.var_to_diff[v_t] = s.var_to_diff[dv] + v_t +end + +""" +Add the equation D(x) - x_t ~ 0 to `neweqs`. `dv` and `v_t` are the indices +of the higher-order derivative variable and the newly-introduced dummy +derivative variable. Return the index of the new equation in `neweqs`. +""" +function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t) + push!(neweqs, eq) + add_vertex!(s.graph, SRC) + dummy_eq = length(neweqs) + add_edge!(s.graph, dummy_eq, dv) + add_edge!(s.graph, dummy_eq, v_t) + add_vertex!(s.solvable_graph, SRC) + add_edge!(s.solvable_graph, dummy_eq, dv) + dummy_eq +end + +""" +Solve the equations in `neweqs` to obtain the final equations of the +system. + +For each equation of `neweqs`, do one of the following: + 1. If the equation is solvable for a differentiated variable D(x), + then solve for D(x), and add D(x) ~ sol as a differential equation + of the system. + 2. 
If the equation is solvable for an un-differentiated variable x, + solve for x and then add x ~ sol as a solved equation. These will + become observables. + 3. If the equation is not solvable, add it as an algebraic equation. + +Solved equations are added to `total_sub`. Occurrences of differential +or solved variables on the RHS of the final equations will get substituted. +The topological sort of the equations ensures that variables are solved for +before they appear in equations. + +Reorder the equations and unknowns to be: + [diffeqs; ...] + [diffvars; ...] +such that the mass matrix is: + [I 0 + 0 0]. + +Order the new equations and variables such that the differential equations +and variables come first. Return the new equations, the solved equations, +the new orderings, and the number of solved variables and equations. +""" +function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; + simplify = false, iv = nothing, D = nothing) + @unpack fullvars, sys, structure = state + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure + eq_var_matching = invview(var_eq_matching) + diff_to_var = invview(var_to_diff) + + total_sub = Dict() + if is_only_discrete(structure) + for (i, v) in enumerate(fullvars) + op = operation(v) + op isa Shift && (op.steps < 0) && + begin + lowered = lower_shift_varname_with_unit(v, iv) + total_sub[v] = lowered + fullvars[i] = lowered + end + end + end + + # if var is like D(x) or Shift(t, 1)(x) + isdervar = let diff_to_var = diff_to_var + var -> diff_to_var[var] !== nothing + end + + # Extract partition information + is_solvable = let solvable_graph = solvable_graph + (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph + end - # Will reorder equations and unknowns to be: - # [diffeqs; ...] - # [diffvars; ...] - # such that the mass matrix is: - # [I 0 - # 0 0]. 
- diffeq_idxs = Int[] - algeeq_idxs = Int[] diff_eqs = Equation[] - alge_eqs = Equation[] + diffeq_idxs = Int[] diff_vars = Int[] - subeqs = Equation[] - solved_equations = Int[] - solved_variables = Int[] - # Solve solvable equations + alge_eqs = Equation[] + algeeq_idxs = Int[] + solved_eqs = Equation[] + solvedeq_idxs = Int[] + solved_vars = Int[] + toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching)) eqs = Iterators.reverse(toporder) - total_sub = Dict() idep = iv + + # Generate equations. + # Solvable equations of differential variables D(x) become differential equations + # Solvable equations of non-differential variables become observable equations + # Non-solvable equations become algebraic equations. for ieq in eqs iv = eq_var_matching[ieq] - if is_solvable(ieq, iv) - # We don't solve differential equations, but we will need to try to - # convert it into the mass matrix form. - # We cannot solve the differential variable like D(x) - if isdervar(iv) - isnothing(D) && - error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. 
Offending equation: $(equations(sys)[ieq])") - order, lv = var_order(iv) - dx = D(simplify_shifts(lower_varname_withshift( - fullvars[lv], idep, order - 1))) - eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub( - Symbolics.symbolic_linear_solve(neweqs[ieq], - fullvars[iv]), - total_sub; operator = ModelingToolkit.Shift)) - for e in 𝑑neighbors(graph, iv) - e == ieq && continue - for v in 𝑠neighbors(graph, e) - add_edge!(graph, e, v) - end - rem_edge!(graph, e, iv) - end - push!(diff_eqs, eq) - total_sub[simplify_shifts(eq.lhs)] = eq.rhs - push!(diffeq_idxs, ieq) - push!(diff_vars, diff_to_var[iv]) - continue + eq = neweqs[ieq] + + if is_solvable(ieq, iv) && isdervar(iv) + var = fullvars[iv] + isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq])) + order, lv = var_order(iv, diff_to_var) + dx = D(simplify_shifts(fullvars[lv])) + + neweq = make_differential_equation(var, dx, eq, total_sub) + for e in 𝑑neighbors(graph, iv) + e == ieq && continue + rem_edge!(graph, e, iv) end - eq = neweqs[ieq] + + total_sub[simplify_shifts(neweq.lhs)] = neweq.rhs + push!(diff_eqs, neweq) + push!(diffeq_idxs, ieq) + push!(diff_vars, diff_to_var[iv]) + elseif is_solvable(ieq, iv) var = fullvars[iv] - residual = eq.lhs - eq.rhs - a, b, islinear = linear_expansion(residual, var) - @assert islinear - # 0 ~ a * var + b - # var ~ -b/a - if ModelingToolkit._iszero(a) - @warn "Tearing: solving $eq for $var is singular!" - else - rhs = -b / a - neweq = var ~ Symbolics.fixpoint_sub( - simplify ? 
- Symbolics.simplify(rhs) : rhs, - total_sub; operator = ModelingToolkit.Shift) - push!(subeqs, neweq) - push!(solved_equations, ieq) - push!(solved_variables, iv) + neweq = make_solved_equation(var, eq, total_sub; simplify) + !isnothing(neweq) && begin + push!(solved_eqs, neweq) + push!(solvedeq_idxs, ieq) + push!(solved_vars, iv) end else - eq = neweqs[ieq] - rhs = eq.rhs - if !(eq.lhs isa Number && eq.lhs == 0) - rhs = eq.rhs - eq.lhs - end - push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub)) + neweq = make_algebraic_equation(eq, total_sub) + push!(alge_eqs, neweq) push!(algeeq_idxs, ieq) end end - # TODO: BLT sorting + + # Generate new equations and orderings neweqs = [diff_eqs; alge_eqs] - inveqsperm = [diffeq_idxs; algeeq_idxs] - eqsperm = zeros(Int, nsrcs(graph)) - for (i, v) in enumerate(inveqsperm) - eqsperm[v] = i - end + eq_ordering = [diffeq_idxs; algeeq_idxs] diff_vars_set = BitSet(diff_vars) if length(diff_vars_set) != length(diff_vars) error("Tearing internal error: lowering DAE into semi-implicit ODE failed!") end - solved_variables_set = BitSet(solved_variables) - invvarsperm = [diff_vars; - setdiff!(setdiff(1:ndsts(graph), diff_vars_set), - solved_variables_set)] + solved_vars_set = BitSet(solved_vars) + var_ordering = [diff_vars; + setdiff!(setdiff(1:ndsts(graph), diff_vars_set), + solved_vars_set)] + + return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), + length(solved_vars_set) +end + +""" +Thrown when a differentiated variable D(x) is found in a non-differential system. +`eq` is the offending equation. +""" +struct UnexpectedDifferentialError <: Exception + eq::Equation +end + +# Write the message to `io`; calling `error` here would throw a second exception while the first is being displayed. +function Base.showerror(io::IO, err::UnexpectedDifferentialError) + print(io, "Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(err.eq)") +end + +""" +Generate a first-order differential equation whose LHS is `dx`. 
+ +`var` and `dx` represent the same variable, but `var` may be a higher-order differential and `dx` is always first-order. For example, if `var` is D(D(x)), then `dx` would be `D(x_t)`. Solve `eq` for `var`, substitute previously solved variables, and return the differential equation. +""" +function make_differential_equation(var, dx, eq, total_sub) + dx ~ simplify_shifts(Symbolics.fixpoint_sub( + Symbolics.symbolic_linear_solve(eq, var), + total_sub; operator = ModelingToolkit.Shift)) +end + +""" +Generate an algebraic equation. Substitute solved variables into `eq` and return the equation. +""" +function make_algebraic_equation(eq, total_sub) + rhs = eq.rhs + if !(eq.lhs isa Number && eq.lhs == 0) + rhs = eq.rhs - eq.lhs + end + 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)) +end + +""" +Solve equation `eq` for `var`, substitute previously solved variables, and return the solved equation. +""" +function make_solved_equation(var, eq, total_sub; simplify = false) + residual = eq.lhs - eq.rhs + a, b, islinear = linear_expansion(residual, var) + @assert islinear + # 0 ~ a * var + b + # var ~ -b/a + if ModelingToolkit._iszero(a) + @warn "Tearing: solving $eq for $var is singular!" + return nothing + else + rhs = -b / a + return var ~ simplify_shifts(Symbolics.fixpoint_sub( + simplify ? + Symbolics.simplify(rhs) : rhs, + total_sub; operator = ModelingToolkit.Shift)) + end +end + +""" +Given the ordering returned by `generate_system_equations!`, update the +tearing state to account for the new order. Permute the variables and equations. +Eliminate the solved variables and equations from the graph and permute the +graph's vertices to account for the new variable/equation ordering. 
+""" +# TODO: BLT sorting +function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, + var_ordering, nsolved_eq, nsolved_var) + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure + + eqsperm = zeros(Int, nsrcs(graph)) + for (i, v) in enumerate(eq_ordering) + eqsperm[v] = i + end varsperm = zeros(Int, ndsts(graph)) - for (i, v) in enumerate(invvarsperm) + for (i, v) in enumerate(var_ordering) varsperm[v] = i end - deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) - for i in 1:length(solved_equations)] # Contract the vertices in the structure graph to make the structure match # the new reality of the system we've just created. - graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm, - length(solved_variables), length(solved_variables_set)) + new_graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm, + nsolved_eq, nsolved_var) - # Update system - new_var_to_diff = complete(DiffGraph(length(invvarsperm))) + new_var_to_diff = complete(DiffGraph(length(var_ordering))) for (v, d) in enumerate(var_to_diff) v′ = varsperm[v] (v′ > 0 && d !== nothing) || continue d′ = varsperm[d] new_var_to_diff[v′] = d′ > 0 ? d′ : nothing end - new_eq_to_diff = complete(DiffGraph(length(inveqsperm))) + new_eq_to_diff = complete(DiffGraph(length(eq_ordering))) for (v, d) in enumerate(eq_to_diff) v′ = eqsperm[v] (v′ > 0 && d !== nothing) || continue d′ = eqsperm[d] new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing end + new_fullvars = state.fullvars[var_ordering] + + # Update system structure + @set! state.structure.graph = complete(new_graph) + @set! state.structure.var_to_diff = new_var_to_diff + @set! state.structure.eq_to_diff = new_eq_to_diff + @set! state.fullvars = new_fullvars + state +end - var_to_diff = new_var_to_diff - eq_to_diff = new_eq_to_diff +""" +Update the system equations, unknowns, and observables after simplification. 
+""" +function update_simplified_system!( + state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; + cse_hack = true, array_hack = true) + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure diff_to_var = invview(var_to_diff) - old_fullvars = fullvars - @set! state.structure.graph = complete(graph) - @set! state.structure.var_to_diff = var_to_diff - @set! state.structure.eq_to_diff = eq_to_diff - @set! state.fullvars = fullvars = fullvars[invvarsperm] ispresent = let var_to_diff = var_to_diff, graph = graph i -> (!isempty(𝑑neighbors(graph, i)) || (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i])))) end sys = state.sys - obs_sub = dummy_sub for eq in neweqs isdiffeq(eq) || continue obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); subeqs] + obs = [fast_substitute(observed(sys), obs_sub); solved_eqs] unknowns = Any[v - for (i, v) in enumerate(fullvars) + for (i, v) in enumerate(state.fullvars) if diff_to_var[i] === nothing && ispresent(i)] - if !isempty(extra_vars) - for v in extra_vars - push!(unknowns, old_fullvars[v]) - end - end + unknowns = [unknowns; extra_unknowns] @set! sys.unknowns = unknowns obs, subeqs, deps = cse_and_array_hacks( - sys, obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack) + sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack) @set! sys.eqs = neweqs @set! sys.observed = obs - @set! sys.substitutions = Substitutions(subeqs, deps) # Only makes sense for time-dependent @@ -606,6 +737,74 @@ function tearing_reassemble(state::TearingState, var_eq_matching, @set! sys.schedule = Schedule(var_eq_matching, dummy_sub) end sys = schedule(sys) +end + +""" +Give the order of the variable indexed by dv. 
+""" +function var_order(dv, diff_to_var) + order = 0 + while (dv′ = diff_to_var[dv]) !== nothing + order += 1 + dv = dv′ + end + order, dv +end + +""" +Main internal function for structural simplification for DAE systems and discrete systems. +Generate dummy derivative variables, new equations in terms of variables, return updated +system and tearing state. + +Terminology and Definition: + +A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can +characterize variables in `u(t)` into two classes: differential variables +(denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential +variables are marked as `SelectedState` and they are differentiated in the +DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually +appear in the system. Algebraic variables are variables that are not +differential variables. +""" +function tearing_reassemble(state::TearingState, var_eq_matching, + full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true) + extra_vars = Int[] + if full_var_eq_matching !== nothing + for v in 𝑑vertices(state.structure.graph) + eq = full_var_eq_matching[v] + eq isa Int && continue + push!(extra_vars, v) + end + end + extra_unknowns = state.fullvars[extra_vars] + neweqs = collect(equations(state)) + dummy_sub = Dict() + + if ModelingToolkit.has_iv(state.sys) + iv = get_iv(state.sys) + if !is_only_discrete(state.structure) + D = Differential(iv) + else + D = Shift(iv, 1) + end + else + iv = D = nothing + end + + # Structural simplification + substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D) + + generate_derivative_variables!(state, neweqs, var_eq_matching; mm, iv, D) + + neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = generate_system_equations!( + state, neweqs, var_eq_matching; simplify, iv, D) + + state = reorder_vars!( + state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var) + + sys = 
update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, + extra_unknowns; cse_hack, array_hack) + @set! state.sys = sys @set! sys.tearing_state = state return invalidate_cache!(sys) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index bd24a1d017..f3cea9c7ba 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -449,13 +449,49 @@ end ### Misc ### -function lower_varname_withshift(var, iv, order) - order == 0 && return var - if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) - op = operation(var) - return Shift(op.t, order)(var) +""" +Handle renaming variable names for discrete structural simplification. Three cases: +- positive shift: do nothing +- zero shift: x(t) => Shift(t, 0)(x(t)) +- negative shift: rename the variable +""" +function lower_shift_varname(var, iv) + op = operation(var) + op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t) + if op.steps < 0 + return shift2term(var) + else + return var end - return lower_varname_with_unit(var, iv, order) +end + +""" +Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t). +""" +function shift2term(var) + op = operation(var) + iv = op.t + arg = only(arguments(var)) + is_lowered = !isnothing(ModelingToolkit.getunshifted(arg)) + + backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps + + num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁ + ds = join([Char(0x209c), Char(0x208b), num]) + # Char(0x209c) = ₜ + # Char(0x208b) = ₋ (subscripted minus) + + O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg + oldop = operation(O) + newname = backshift != 0 ? 
Symbol(string(nameof(oldop)), ds) : + Symbol(string(nameof(oldop))) + + newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), + Symbolics.children(O), Symbolics.metadata(O)) + newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname)) + newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) + newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) + return newvar end function isdoubleshift(var) @@ -466,6 +502,7 @@ end function simplify_shifts(var) ModelingToolkit.hasshift(var) || return var var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs) + (op = operation(var)) isa Shift && op.steps == 0 && return first(arguments(var)) if isdoubleshift(var) op1 = operation(var) vv1 = arguments(var)[1] diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 5af8e24c58..8359a3e7de 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -269,15 +269,22 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs) for k in collect(keys(u0map)) v = u0map[k] if !((op = operation(k)) isa Shift) - error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).") + isnothing(getunshifted(k)) && + error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).") + + updated[Shift(iv, 1)(k)] = v + elseif op.steps > 0 + error("Initial conditions must be for the past state of the unknowns. 
Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).") + else + updated[Shift(iv, op.steps + 1)(only(arguments(k)))] = v end - updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v end for var in unknowns(sys) op = operation(var) - op isa Shift || continue - haskey(updated, var) && continue - root = first(arguments(var)) + root = getunshifted(var) + shift = getshift(var) + isnothing(root) && continue + (haskey(updated, Shift(iv, shift)(root)) || haskey(updated, var)) && continue haskey(defs, root) || error("Initial condition for $var not provided.") updated[var] = defs[root] end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1bdc11f06a..0643f32ec4 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -140,16 +140,20 @@ get_fullvars(ts::TransformationState) = ts.fullvars has_equations(::TransformationState) = true Base.@kwdef mutable struct SystemStructure - # Maps the (index of) a variable to the (index of) the variable describing - # its derivative. + """Maps the index of variable x to the index of variable D(x).""" var_to_diff::DiffGraph + """Maps the index of an algebraic equation to the index of the equation it is differentiated into.""" eq_to_diff::DiffGraph # Can be access as # `graph` to automatically look at the bipartite graph # or as `torn` to assert that tearing has run. 
+ """Graph that connects equations to variables that appear in them.""" graph::BipartiteGraph{Int, Nothing} + """Graph that connects equations to the variable they will be solved for during simplification.""" solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing} + """Variable types (brownian, variable, parameter) in the system.""" var_types::Union{Vector{VariableType}, Nothing} + """Whether the system is discrete.""" only_discrete::Bool end @@ -197,7 +201,9 @@ function complete!(s::SystemStructure) end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} + """The system of equations.""" sys::T + """The set of variables of the system.""" fullvars::Vector structure::SystemStructure extra_eqs::Vector @@ -346,6 +352,8 @@ function TearingState(sys; quick_cancel = false, check = true) eqs[i] = eqs[i].lhs ~ rhs end end + + ### Handle discrete variables lowest_shift = Dict() for var in fullvars if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) @@ -464,9 +472,11 @@ function shift_discrete_system(ts::TearingState) vars!(discvars, eq; op = Union{Sample, Hold}) end iv = get_iv(sys) + discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) for k in discvars if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) + for i in eachindex(fullvars) fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( fullvars[i], discmap; operator = Union{Sample, Hold})) diff --git a/src/utils.jl b/src/utils.jl index 962801622a..747a2833d8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1028,6 +1028,8 @@ end diff2term_with_unit(x, t) = _with_unit(diff2term, x, t) lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order) +shift2term_with_unit(x, t) = _with_unit(shift2term, x, t) +lower_shift_varname_with_unit(var, iv) = _with_unit(lower_shift_varname, var, iv, iv) """ $(TYPEDSIGNATURES) diff --git a/src/variables.jl b/src/variables.jl index 83f0681a09..ddc33fa838 100644 --- 
a/src/variables.jl +++ b/src/variables.jl @@ -6,6 +6,9 @@ struct VariableOutput end struct VariableIrreducible end struct VariableStatePriority end struct VariableMisc end +# Metadata for renamed shift variables xₜ₋₁ +struct VariableUnshifted end +struct VariableShift end Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput @@ -13,6 +16,8 @@ Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput Symbolics.option_to_metadata_type(::Val{:irreducible}) = VariableIrreducible Symbolics.option_to_metadata_type(::Val{:state_priority}) = VariableStatePriority Symbolics.option_to_metadata_type(::Val{:misc}) = VariableMisc +Symbolics.option_to_metadata_type(::Val{:unshifted}) = VariableUnshifted +Symbolics.option_to_metadata_type(::Val{:shift}) = VariableShift """ dump_variable_metadata(var) @@ -95,7 +100,7 @@ struct Stream <: AbstractConnectType end # special stream connector Get the connect type of x. See also [`hasconnect`](@ref). """ -getconnect(x) = getconnect(unwrap(x)) +getconnect(x::Num) = getconnect(unwrap(x)) getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing) """ hasconnect(x) @@ -134,7 +139,7 @@ function default_toterm(x) if iscall(x) && (op = operation(x)) isa Operator if !(op isa Differential) if op isa Shift && op.steps < 0 - return x + return shift2term(x) end x = normalize_to_differential(op)(arguments(x)...) end @@ -263,7 +268,7 @@ end end struct IsHistory end -ishistory(x) = ishistory(unwrap(x)) +ishistory(x::Num) = ishistory(unwrap(x)) ishistory(x::Symbolic) = getmetadata(x, IsHistory, false) hist(x, t) = wrap(hist(unwrap(x), t)) function hist(x::Symbolic, t) @@ -575,7 +580,7 @@ end Fetch any miscellaneous data associated with symbolic variable `x`. See also [`hasmisc(x)`](@ref). 
""" -getmisc(x) = getmisc(unwrap(x)) +getmisc(x::Num) = getmisc(unwrap(x)) getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing) """ hasmisc(x) @@ -594,7 +599,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata) Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions. """ -getunit(x) = getunit(unwrap(x)) +getunit(x::Num) = getunit(unwrap(x)) getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing) """ hasunit(x) @@ -602,3 +607,9 @@ getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing) Check if the variable `x` has a unit. """ hasunit(x) = getunit(x) !== nothing + +getunshifted(x::Num) = getunshifted(unwrap(x)) +getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) + +getshift(x::Num) = getshift(unwrap(x)) +getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) diff --git a/test/discrete_system.jl b/test/discrete_system.jl index 78afafd51d..eea0ffc36b 100644 --- a/test/discrete_system.jl +++ b/test/discrete_system.jl @@ -222,21 +222,6 @@ sol = solve(prob, FunctionMap()) @test reduce(vcat, sol.u) == 1:11 -# test that default values apply to the entire history -@variables x(t) = 1.0 -@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t) -prob = DiscreteProblem(de, [], (0, 10)) -@test prob[x] == 2.0 -@test prob[x(k - 1)] == 1.0 - -# must provide initial conditions for history -@test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10)) - -# initial values only affect _that timestep_, not the entire history -prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10)) -@test prob[x] == 3.0 -@test prob[x(k - 1)] == 2.0 - # Issue#2585 getdata(buffer, t) = buffer[mod1(Int(t), length(buffer))] @register_symbolic getdata(buffer::Vector, t) @@ -281,3 +266,65 @@ k = ShiftIndex(t) prob = @test_nowarn DiscreteProblem(sys, nothing, 
(0.0, 1.0)) @test_nowarn solve(prob, FunctionMap()) end + +@testset "Initialization" begin + # test that default values apply to the entire history + @variables x(t) = 1.0 + @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t) + prob = DiscreteProblem(de, [], (0, 10)) + @test prob[x] == 2.0 + @test prob[x(k - 1)] == 1.0 + + # must provide initial conditions for history + @test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10)) + @test_throws ErrorException DiscreteProblem(de, [x(k + 1) => 2.0], (0, 10)) + + # initial values only affect _that timestep_, not the entire history + prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10)) + @test prob[x] == 3.0 + @test prob[x(k - 1)] == 2.0 + @variables xₜ₋₁(t) + @test prob[xₜ₋₁] == 2.0 + + # Test initial assignment with lowered variable + prob = DiscreteProblem(de, [xₜ₋₁(k - 1) => 4.0], (0, 10)) + @test prob[x(k - 1)] == prob[xₜ₋₁] == 1.0 + @test prob[x] == 5.0 + + # Test missing initial throws error + @variables x(t) + @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t) + @test_throws ErrorException prob=DiscreteProblem(de, [x(k - 3) => 2.0], (0, 10)) + @test_throws ErrorException prob=DiscreteProblem( + de, [x(k - 3) => 2.0, x(k - 1) => 3.0], (0, 10)) + + # Test non-assigned initials are given default value + @variables x(t) = 2.0 + @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t) + prob = DiscreteProblem(de, [x(k - 3) => 12.0], (0, 10)) + @test prob[x] == 26.0 + @test prob[x(k - 1)] == 2.0 + @test prob[x(k - 2)] == 2.0 + + # Elaborate test + @variables xₜ₋₂(t) zₜ₋₁(t) z(t) + eqs = [x ~ x(k - 1) + z(k - 2), + z ~ x(k - 2) * x(k - 3) - z(k - 1)^2] + @mtkbuild de = DiscreteSystem(eqs, t) + u0 = [x(k - 1) => 3, + xₜ₋₂(k - 1) => 4, + x(k - 2) => 1, + z(k - 1) => 5, + zₜ₋₁(k - 1) => 12] + prob = DiscreteProblem(de, u0, (0, 10)) + @test prob[x] == 15 + @test prob[z] == -21 + + import ModelingToolkit: shift2term + # unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z + vars = 
ModelingToolkit.value.(unknowns(de)) + @test isequal(shift2term(Shift(t, 1)(vars[1])), vars[2]) + @test isequal(shift2term(Shift(t, 1)(vars[4])), vars[1]) + @test isequal(shift2term(Shift(t, -1)(vars[5])), vars[3]) + @test isequal(shift2term(Shift(t, -2)(vars[2])), vars[4]) +end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 863e091aad..4da3d1e924 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -4,6 +4,7 @@ using Graphs using SparseArrays using UnPack using ModelingToolkit: t_nounits as t, D_nounits as D +const ST = StructuralTransformations # Define some variables @parameters L g