From 48ec55c15ed46d03d6a5e1e8d0627fd026dde311 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Feb 2025 13:22:23 +0530 Subject: [PATCH 1/3] feat: add `map_variables_to_equations` --- src/ModelingToolkit.jl | 1 + src/systems/systems.jl | 54 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index b683e132a2..48070f244a 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -276,6 +276,7 @@ export TearingState export BipartiteGraph, equation_dependencies, variable_dependencies export eqeq_dependencies, varvar_dependencies export asgraph, asdigraph +export map_variables_to_equations export toexpr, get_variables export simplify, substitute diff --git a/src/systems/systems.jl b/src/systems/systems.jl index ef0f966eee..6151ffa515 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -158,3 +158,57 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end + +""" + $(TYPEDSIGNATURES) + +Given a system that has been simplified via `structural_simplify`, return a `Dict` mapping +variables of the system to equations that are used to solve for them. This includes +observed variables. + +# Keyword Arguments + +- `rename_dummy_derivatives`: Whether to rename dummy derivative variable keys into their + `Differential` forms. For example, this would turn the key `y藣t(t)` into + `Differential(t)(y(t))`. +""" +function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivatives = true) + if !has_tearing_state(sys) + throw(ArgumentError("$(typeof(sys)) is not supported.")) + end + ts = get_tearing_state(sys) + if ts === nothing + throw(ArgumentError("`map_variables_to_equations` requires a simplified system. Call `structural_simplify` on the system before calling this function.")) + end + + dummy_sub = Dict() + if rename_dummy_derivatives && has_schedule(sys) && (sc = get_schedule(sys)) !== nothing + dummy_sub = Dict(v => k for (k, v) in sc.dummy_sub if isequal(default_toterm(k), v)) + end + + mapping = Dict{Union{Num, BasicSymbolic}, Equation}() + eqs = equations(sys) + for eq in eqs + isdifferential(eq.lhs) || continue + var = arguments(eq.lhs)[1] + var = get(dummy_sub, var, var) + mapping[var] = eq + end + + graph = ts.structure.graph + algvars = BitSet(findall( + Base.Fix1(StructuralTransformations.isalgvar, ts.structure), 1:ndsts(graph))) + algeqs = BitSet(findall(1:nsrcs(graph)) do eq + all(!Base.Fix1(isdervar, ts.structure), 饾憼neighbors(graph, eq)) + end) + alge_var_eq_matching = complete(maximal_matching(graph, in(algeqs), in(algvars))) + for (i, eq) in enumerate(alge_var_eq_matching) + eq isa Unassigned && continue + mapping[get(dummy_sub, ts.fullvars[i], ts.fullvars[i])] = eqs[eq] + end + for eq in observed(sys) + mapping[get(dummy_sub, eq.lhs, eq.lhs)] = eq + end + + return mapping +end From 74de88b8acfb0efb9045801604969e682bc3504d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Feb 2025 13:22:30 +0530 Subject: [PATCH 2/3] test: test `map_variables_to_equations` --- test/structural_transformation/utils.jl | 98 ++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 4da3d1e924..24cfb98d45 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -3,8 +3,8 @@ using ModelingToolkit using Graphs using SparseArrays using UnPack -using ModelingToolkit: t_nounits as t, D_nounits as D -const ST = StructuralTransformations +using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm +using Symbolics: unwrap # Define some variables @parameters L g @@ -162,3 +162,97 @@ end structural_simplify(sys; additional_passes = [pass]) @test value[] == 1 end + +@testset "`map_variables_to_equations`" begin + @testset "Not supported for systems without `.tearing_state`" begin + @variables x + @mtkbuild sys = OptimizationSystem(x^2) + @test_throws ArgumentError map_variables_to_equations(sys) + end + @testset "Requires simplified system" begin + @variables x(t) y(t) + @named sys = ODESystem([D(x) ~ x, y ~ 2x], t) + sys = complete(sys) + @test_throws ArgumentError map_variables_to_equations(sys) + end + @testset "`ODESystem`" begin + @variables x(t) y(t) z(t) + @mtkbuild sys = ODESystem([D(x) ~ 2x + y, y ~ x + z, z^3 + x^3 ~ 12], t) + mapping = map_variables_to_equations(sys) + @test mapping[x] == (D(x) ~ 2x + y) + @test mapping[y] == (y ~ x + z) + @test mapping[z] == (0 ~ 12 - z^3 - x^3) + @test length(mapping) == 3 + + @testset "With dummy derivatives" begin + @parameters g + @variables x(t) y(t) [state_priority = 10] 位(t) + eqs = [D(D(x)) ~ 位 * x + D(D(y)) ~ 位 * y - g + x^2 + y^2 ~ 1] + @mtkbuild sys = ODESystem(eqs, t) + mapping = map_variables_to_equations(sys) + + yt = default_toterm(unwrap(D(y))) + xt = default_toterm(unwrap(D(x))) + xtt = default_toterm(unwrap(D(D(x)))) + @test mapping[x] == (0 ~ 1 - x^2 - y^2) + @test mapping[y] == (D(y) ~ yt) + @test mapping[D(y)] == (D(yt) ~ -g + y * 位) + @test mapping[D(x)] == (0 ~ -2xt * x - 2yt * y) + @test mapping[D(D(x))] == (xtt ~ x * 位) + @test length(mapping) == 5 + + @testset "`rename_dummy_derivatives = false`" begin + mapping = map_variables_to_equations(sys; rename_dummy_derivatives = false) + + @test mapping[x] == (0 ~ 1 - x^2 - y^2) + @test mapping[y] == (D(y) ~ yt) + @test mapping[yt] == (D(yt) ~ -g + y * 位) + @test mapping[xt] == (0 ~ -2xt * x - 2yt * y) + @test mapping[xtt] == (xtt ~ x * 位) + @test length(mapping) == 5 + end + end + @testset "DDEs" begin + function oscillator(; name, k = 1.0, 蟿 = 0.01) + @parameters k=k 蟿=蟿 + @variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t) + eqs = [D(x(t)) ~ y, + D(y) ~ -k * x(t - 蟿) + jcn, + delx ~ x(t - 蟿)] + return System(eqs, t; name = name) + end + + systems = @named begin + osc1 = oscillator(k = 1.0, 蟿 = 0.01) + osc2 = oscillator(k = 2.0, 蟿 = 0.04) + end + eqs = [osc1.jcn ~ osc2.delx, + osc2.jcn ~ osc1.delx] + @named coupledOsc = System(eqs, t) + @mtkbuild sys = compose(coupledOsc, systems) + mapping = map_variables_to_equations(sys) + x1 = operation(unwrap(osc1.x)) + x2 = operation(unwrap(osc2.x)) + @test mapping[osc1.x] == (D(osc1.x) ~ osc1.y) + @test mapping[osc1.y] == (D(osc1.y) ~ osc1.jcn - osc1.k * x1(t - osc1.蟿)) + @test mapping[osc1.delx] == (osc1.delx ~ x1(t - osc1.蟿)) + @test mapping[osc1.jcn] == (osc1.jcn ~ osc2.delx) + @test mapping[osc2.x] == (D(osc2.x) ~ osc2.y) + @test mapping[osc2.y] == (D(osc2.y) ~ osc2.jcn - osc2.k * x2(t - osc2.蟿)) + @test mapping[osc2.delx] == (osc2.delx ~ x2(t - osc2.蟿)) + @test mapping[osc2.jcn] == (osc2.jcn ~ osc1.delx) + @test length(mapping) == 8 + end + end + @testset "`NonlinearSystem`" begin + @variables x y z + @mtkbuild sys = NonlinearSystem([x^2 ~ 2y^2 + 1, sin(z) ~ y, z^3 + 4z + 1 ~ 0]) + mapping = map_variables_to_equations(sys) + @test mapping[x] == (0 ~ 2y^2 + 1 - x^2) + @test mapping[y] == (y ~ sin(z)) + @test mapping[z] == (0 ~ -1 - 4z - z^3) + @test length(mapping) == 3 + end +end From 55601e07318531341c118de9b3f4b38416a03373 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Feb 2025 13:24:28 +0530 Subject: [PATCH 3/3] docs: add `map_variables_to_equations` to docs --- docs/src/basics/DependencyGraphs.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/basics/DependencyGraphs.md b/docs/src/basics/DependencyGraphs.md index 67de5ea8f1..73fbbca9be 100644 --- a/docs/src/basics/DependencyGraphs.md +++ b/docs/src/basics/DependencyGraphs.md @@ -22,3 +22,9 @@ asdigraph eqeq_dependencies varvar_dependencies ``` + +# Miscellaneous + +```@docs +map_variables_to_equations +```