From 21d1ce7e3707c890633d363bba21301d3a69bdbd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 15:05:28 +0530 Subject: [PATCH 1/3] fix: construct `initializeprob` if initial value is symbolic --- src/systems/problem_utils.jl | 7 ++++++- test/initializationsystem.jl | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index e530e62eed..cc2e959ccd 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -433,9 +433,14 @@ function process_SciMLProblem( solvablepars = [p for p in parameters(sys) if is_parameter_solvable(p, pmap, defs, guesses)] + has_dependent_unknowns = any(unknowns(sys)) do sym + val = get(op, sym, nothing) + val === nothing && return false + return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val) + end if build_initializeprob && (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) || - !isempty(solvablepars)) && + !isempty(solvablepars) || has_dependent_unknowns) && get_tearing_state(sys) !== nothing) || !isempty(initialization_equations(sys))) && t !== nothing initializeprob = ModelingToolkit.InitializationProblem( diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index de48821cff..3eddba2b78 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -844,3 +844,19 @@ end isys = ModelingToolkit.generate_initializesystem(sys) @test isequal(defaults(isys)[y], 2x + 1) end + +@testset "Create initializeprob when unknown has dependent value" begin + @variables x(t) y(t) + @mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t * y], t; defaults = [x => 2y]) + prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0)) + @test prob.f.initializeprob !== nothing + integ = init(prob) + @test integ[x] ≈ 2.0 + + @variables x(t)[1:2] y(t) + @mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]]) + prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0)) + @test prob.f.initializeprob !== nothing + integ = init(prob) + @test integ[x] ≈ [1.0, 3.0] +end From 6081a502fdd102d311ba296a94c4963ea1a61c65 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 15:27:58 +0530 Subject: [PATCH 2/3] fix: recursively unwrap arrays of symbolics in `process_SciMLProblem` --- src/systems/problem_utils.jl | 21 +++++++++++++++++++-- test/initial_values.jl | 9 +++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index cc2e959ccd..ce46f2762a 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -171,7 +171,24 @@ function to_varmap(vals, varlist::Vector) check_eqs_u0(varlist, varlist, vals) vals = vec(varlist) .=> vec(vals) end - return anydict(unwrap(k) => unwrap(v) for (k, v) in anydict(vals)) + return recursive_unwrap(anydict(vals)) +end + +""" + $(TYPEDSIGNATURES) + +Recursively call `Symbolics.unwrap` on `x`. Useful when `x` is an array of (potentially) +symbolic values, all of which need to be unwrapped. Specializes when `x isa AbstractDict` +to unwrap keys and values, returning an `AnyDict`. +""" +function recursive_unwrap(x::AbstractArray) + symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x) +end + +recursive_unwrap(x) = unwrap(x) + +function recursive_unwrap(x::AbstractDict) + return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x) end """ @@ -410,7 +427,7 @@ function process_SciMLProblem( u0map = to_varmap(u0map, dvs) _pmap = pmap pmap = to_varmap(pmap, ps) - defs = add_toterms(defaults(sys)) + defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) diff --git a/test/initial_values.jl b/test/initial_values.jl index 2ff1b0c2a3..12fb7633e9 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -119,3 +119,12 @@ end prob = ODEProblem(sys, [], (1.0, 2.0), []) @test prob[x] == 1.0 @test prob.ps[p] == 2.0 + +@testset "Array of symbolics is unwrapped" begin + @variables x(t)[1:2] y(t) + @mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]]) + prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0)) + @test eltype(prob.u0) <: Float64 + prob = ODEProblem(sys, [x => [y, 4.0], y => 2.0], (0.0, 1.0)) + @test eltype(prob.u0) <: Float64 +end From 2184598d1d88f0a3f1b1f760d4040d5651f2e5b3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 15:42:51 +0530 Subject: [PATCH 3/3] fix: handle all parameter values from defaults in `split = false` systems --- src/systems/problem_utils.jl | 2 +- test/initial_values.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ce46f2762a..daf1c164d8 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -279,7 +279,7 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; end vals = map(x -> varmap[x], vars) - if container_type <: Union{AbstractDict, Tuple, Nothing} + if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters} container_type = Array end diff --git a/test/initial_values.jl b/test/initial_values.jl index 12fb7633e9..9911bab7f4 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -128,3 +128,11 @@ prob = ODEProblem(sys, [], (1.0, 2.0), []) prob = ODEProblem(sys, [x => [y, 4.0], y => 2.0], (0.0, 1.0)) @test eltype(prob.u0) <: Float64 end + +@testset "split=false systems with all parameter defaults" begin + @variables x(t) = 1.0 + @parameters p=1.0 q=2.0 r=3.0 + @mtkbuild sys=ODESystem(D(x) ~ p * x + q * t + r, t) split=false + prob = @test_nowarn ODEProblem(sys, [], (0.0, 1.0)) + @test prob.p isa Vector{Float64} +end