From ef75b3352d20aa5ee56f93edfd867c0ee9909474 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Mar 2025 17:22:27 +0530 Subject: [PATCH 1/3] feat: allow `NonlinearSystem(::ODESystem)` and `NonlinearProblem(::ODESystem)` --- src/systems/nonlinear/nonlinearsystem.jl | 30 ++++++++++++++ test/nonlinearsystem.jl | 50 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index f93e8a5838..523fee28fc 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -232,6 +232,32 @@ function NonlinearSystem(eqs; kwargs...) return NonlinearSystem(eqs, collect(allunknowns), collect(new_ps); kwargs...) end +""" + $(TYPEDSIGNATURES) + +Convert an `ODESystem` to a `NonlinearSystem` solving for its steady state (where derivatives are zero). +Any differential variable `D(x) ~ f(...)` will be turned into `0 ~ f(...)`. The returned system is not +simplified. If the input system is `complete`d, then so will the returned system. +""" +function NonlinearSystem(sys::ODESystem) + eqs = equations(sys) + obs = observed(sys) + subrules = Dict(D(x) => 0.0 for x in unknowns(sys)) + eqs = map(eqs) do eq + fast_substitute(eq, subrules) + end + + nsys = NonlinearSystem(eqs, unknowns(sys), [parameters(sys); get_iv(sys)]; + parameter_dependencies = parameter_dependencies(sys), + defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys), + initialization_eqs = initialization_equations(sys), name = nameof(sys), + observed = obs) + if iscomplete(sys) + nsys = complete(nsys; split = is_split(sys)) + end + return nsys +end + function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = false) cache = get_jac(sys)[] if cache isa Tuple && cache[2] == (sparse, simplify) @@ -529,6 +555,10 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)) end +function DiffEqBase.NonlinearProblem(sys::ODESystem, args...; kwargs...) + NonlinearProblem(NonlinearSystem(sys), args...; kwargs...) +end + """ ```julia DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0map, diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index c8dda530e2..cde0be7373 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -390,3 +390,53 @@ end @test !any(isequal(p[1]), parameters(sys)) @test is_parameter(sys, p) end + +@testset "Can convert from `ODESystem`" begin + @variables x(t) y(t) + @parameters p q r + @named sys = ODESystem([D(x) ~ p * x^3 + q, 0 ~ -y + q * x - r], t; + defaults = [x => 1.0, p => missing], guesses = [p => 1.0], + initialization_eqs = [p^3 + q^3 ~ 4r], parameter_dependencies = [r ~ 3p]) + nlsys = NonlinearSystem(sys) + defs = defaults(nlsys) + @test length(defs) == 3 + @test defs[x] == 1.0 + @test defs[p] === missing + @test isinf(defs[t]) + @test length(guesses(nlsys)) == 1 + @test guesses(nlsys)[p] == 1.0 + @test length(initialization_equations(nlsys)) == 1 + @test length(parameter_dependencies(nlsys)) == 1 + @test length(equations(nlsys)) == 2 + @test all(iszero, [eq.lhs for eq in equations(nlsys)]) + @test nameof(nlsys) == nameof(sys) + @test !ModelingToolkit.iscomplete(nlsys) + + sys1 = complete(sys; split = false) + nlsys = NonlinearSystem(sys1) + @test ModelingToolkit.iscomplete(nlsys) + @test !ModelingToolkit.is_split(nlsys) + + sys2 = complete(sys) + nlsys = NonlinearSystem(sys2) + @test ModelingToolkit.iscomplete(nlsys) + @test ModelingToolkit.is_split(nlsys) + + sys3 = structural_simplify(sys) + nlsys = NonlinearSystem(sys3) + @test length(equations(nlsys)) == length(observed(nlsys)) == 1 + + prob = NonlinearProblem(sys3, [q => 2.0]) + @test prob.f.initialization_data.initializeprobmap === nothing + sol = solve(prob) + @test SciMLBase.successful_retcode(sol) + @test sol.ps[p^3 + q^3]≈sol.ps[4r] atol=1e-10 + + @testset "Differential inside expression also substituted" begin + @named sys = ODESystem([0 ~ y * D(x) + x^2 - p, 0 ~ x * D(y) + y * p], t) + nlsys = NonlinearSystem(sys) + vs = ModelingToolkit.vars(equations(nlsys)) + @test !in(D(x), vs) + @test !in(D(y), vs) + end +end From ee84d23969f1eea39add5bed3feba6e9d74faec8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Mar 2025 18:11:03 +0530 Subject: [PATCH 2/3] fix: dispatch on `AbstractODESystem` to avoid use-before-define problems --- src/systems/nonlinear/nonlinearsystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 523fee28fc..ea3730fdb6 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -239,7 +239,7 @@ Convert an `ODESystem` to a `NonlinearSystem` solving for its steady state (wher Any differential variable `D(x) ~ f(...)` will be turned into `0 ~ f(...)`. The returned system is not simplified. If the input system is `complete`d, then so will the returned system. """ -function NonlinearSystem(sys::ODESystem) +function NonlinearSystem(sys::AbstractODESystem) eqs = equations(sys) obs = observed(sys) subrules = Dict(D(x) => 0.0 for x in unknowns(sys)) @@ -555,7 +555,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)) end -function DiffEqBase.NonlinearProblem(sys::ODESystem, args...; kwargs...) +function DiffEqBase.NonlinearProblem(sys::AbstractODESystem, args...; kwargs...) NonlinearProblem(NonlinearSystem(sys), args...; kwargs...) end From 82b9b40371aedb6165fb7d88de080a6631f7b181 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 5 Mar 2025 11:29:25 +0530 Subject: [PATCH 3/3] fix: disambiguate `observed` --- test/nonlinearsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index cde0be7373..420bdc088e 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -424,7 +424,7 @@ end sys3 = structural_simplify(sys) nlsys = NonlinearSystem(sys3) - @test length(equations(nlsys)) == length(observed(nlsys)) == 1 + @test length(equations(nlsys)) == length(ModelingToolkit.observed(nlsys)) == 1 prob = NonlinearProblem(sys3, [q => 2.0]) @test prob.f.initialization_data.initializeprobmap === nothing