diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index c17dfa0bce..9558c6ba08 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -49,6 +49,8 @@ jobs: - {user: SciML, repo: SciMLSensitivity.jl, group: Core3} - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core6} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core7} - {user: SciML, repo: Catalyst.jl, group: All} steps: diff --git a/Project.toml b/Project.toml index 6c78004166..d1d1d447b1 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" -SciMLBaseZygoteExt = "Zygote" +SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] ADTypes = "0.2.5,1.0.0" diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 934a0a1dfa..278bb00350 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,8 +1,9 @@ module SciMLBaseChainRulesCoreExt using SciMLBase +using SciMLBase: getobserved import ChainRulesCore -import ChainRulesCore: NoTangent, @non_differentiable +import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad using SymbolicIndexingInterface function ChainRulesCore.rrule( @@ -15,52 +16,28 @@ function ChainRulesCore.rrule( j::Integer) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - if i === nothing + du, dprob = if i === nothing getter = getobserved(VA) grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) - du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] - dp = grz[3] # pullback for p + du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)] + dp = grz[4] # pullback for p + if dp == NoTangent() + dp = zero_tangent(parameter_values(VA.prob)) + end dprob = remake(VA.prob, p = dp) - T = eltype(eltype(VA.u)) - N = length(VA.prob.p) - Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing, - typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing, - nothing, nothing, nothing, dprob, nothing, nothing, - VA.dense, 0, nothing, nothing, VA.retcode) - (NoTangent(), Δ′, NoTangent(), NoTangent()) + du, dprob else du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)] - dp = zero(VA.prob.p) + dp = zero_tangent(VA.prob.p) dprob = remake(VA.prob, p = dp) - Δ′ = ODESolution{ - T, - N, - typeof(du), - Nothing, - Nothing, - typeof(VA.t), - typeof(VA.k), - typeof(dprob), - typeof(VA.alg), - typeof(VA.interp), - typeof(VA.alg_choice), - typeof(VA.stats) - }(du, - nothing, - nothing, - VA.t, - VA.k, - dprob, - VA.alg, - VA.interp, - VA.dense, - 0, - VA.stats, - VA.alg_choice, - VA.retcode) - (NoTangent(), Δ′, NoTangent(), NoTangent()) + du, dprob end + T = eltype(eltype(du)) + N = ndims(eltype(du)) + 1 + Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob, + VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) + (NoTangent(), Δ′, NoTangent(), NoTangent()) end VA[sym, j], ODESolution_getindex_pullback end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index fa67f06464..902228b03f 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -3,6 +3,7 @@ module SciMLBaseZygoteExt using Zygote using Zygote: @adjoint, pullback import Zygote: literal_getproperty +import ChainRulesCore using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, @@ -40,31 +41,9 @@ import SciMLStructures VA[i, j], ODESolution_getindex_pullback end -@adjoint function Base.getindex(VA::ODESolution, sym, j::Int) - function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - du, dprob = if i === nothing - getter = getobserved(VA) - grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) - du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] - dp = grz[3] # pullback for p - dprob = remake(VA.prob, p = dp) - du, dprob - else - du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : - zero(VA.u[1]) for m in 1:length(VA.u)] - dp = zero(VA.prob.p) - dprob = remake(VA.prob, p = dp) - du, dprob - end - T = eltype(eltype(VA.u)) - N = ndims(VA) - Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, - VA.alg_choice, VA.retcode) - (Δ′, nothing, nothing) - end - VA[sym, j], ODESolution_getindex_pullback +@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer) + res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j) + return res, Base.tail ∘ pullback end @adjoint function EnsembleSolution(sim, time, converged, stats) diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index 5cc10db511..68311d22f1 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -129,10 +129,14 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem} function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed) if f isa AbstractSDEFunction iip = isinplace(f) + if g !== f.g + f = remake(f; g) + end + return SDEProblem{iip}(f, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) else iip = isinplace(f, 4) + return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) end - return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) end end diff --git a/src/remake.jl b/src/remake.jl index a5417a27eb..16124b17c9 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -958,7 +958,7 @@ end function Base.showerror(io::IO, err::CyclicDependencyError) println(io, "Detected cyclic dependency in initial values:") for (k, v) in err.varmap - println(io, k, " => ", "v") + println(io, k, " => ", v) end println(io, "While trying to solve for variables: ", err.vars) end @@ -1085,10 +1085,6 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R the updated `newu0` and `newp`. """ function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) - if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) && - (sc = symbolic_container(root_indp)) !== root_indp - return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp) - end return newu0, newp end @@ -1099,7 +1095,7 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after `root_indp`. """ function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp) - root_indp = prob + root_indp = get_root_indp(prob) return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index a9b2622943..139afd3b00 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4818,8 +4818,10 @@ for S in [:ODEFunction end end +const EMPTY_SYMBOLCACHE = SymbolCache() + function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) - has_sys(fn) ? fn.sys : SymbolCache() + has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE end function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index 3d86f306a2..2559eaf349 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t) return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) end -function is_empty_indp(indp) - isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) && - isempty(independent_variable_symbols(indp)) +function get_root_indp(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp + return get_root_indp(sc) + end + return indp end # Everything from this point on is public API @@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C} partition_count::C end -function SavedSubsystem(indp, pobj, saved_idxs) - # nothing saved - if saved_idxs === nothing || isempty(saved_idxs) +SavedSubsystem(indp, pobj, ::Nothing) = nothing + +function SavedSubsystem(indp, pobj, idx::Int) + _indp = get_root_indp(indp) + if _indp === EMPTY_SYMBOLCACHE || _indp === nothing return nothing end + state_map = Dict(1 => idx) + return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) +end - # this is required because problems with no system have an empty `SymbolCache` - # as their symbolic container. - if is_empty_indp(indp) +function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple}) + _indp = get_root_indp(indp) + if _indp === EMPTY_SYMBOLCACHE || _indp === nothing return nothing end + if eltype(saved_idxs) == Int + state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs)) + return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) + end # array state symbolics must be scalarized saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym @@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case one is not required. `save_idxs` may be a scalar or `nothing`. """ +get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing +function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int}) + save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs) +end +function get_save_idxs_and_saved_subsystem(prob, save_idx::Int) + save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx) +end function get_save_idxs_and_saved_subsystem(prob, save_idxs) - if save_idxs === nothing - saved_subsystem = nothing + if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + _save_idxs = (save_idxs,) else - if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() - _save_idxs = [save_idxs] + _save_idxs = save_idxs + end + saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) + if saved_subsystem !== nothing + _save_idxs = get_saved_state_idxs(saved_subsystem) + if isempty(_save_idxs) + # no states to save + save_idxs = Int[] + elseif !(save_idxs isa AbstractArray) || + symbolic_type(save_idxs) != NotSymbolic() + # only a single state to save, and save it as a scalar timeseries instead of + # single-element array + save_idxs = only(_save_idxs) else - _save_idxs = save_idxs - end - saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) - if saved_subsystem !== nothing - _save_idxs = get_saved_state_idxs(saved_subsystem) - if isempty(_save_idxs) - # no states to save - save_idxs = Int[] - elseif !(save_idxs isa AbstractArray) || - symbolic_type(save_idxs) != NotSymbolic() - # only a single state to save, and save it as a scalar timeseries instead of - # single-element array - save_idxs = only(_save_idxs) - else - save_idxs = _save_idxs - end + save_idxs = _save_idxs end end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index c390a258f4..5e65ebbf4f 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -34,7 +34,7 @@ DelayDiffEq = "5" DiffEqCallbacks = "3, 4" ForwardDiff = "0.10" JumpProcesses = "9.10" -ModelingToolkit = "9.64.1" +ModelingToolkit = "9.64.3" ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3, 4" Optimization = "4" diff --git a/test/downstream/adjoints.jl b/test/downstream/adjoints.jl index 4e75e19b6f..9269de6b40 100644 --- a/test/downstream/adjoints.jl +++ b/test/downstream/adjoints.jl @@ -22,8 +22,7 @@ u0 = [lorenz1.x => 1.0, lorenz1.z => 0.0, lorenz2.x => 0.0, lorenz2.y => 1.0, - lorenz2.z => 0.0, - a => 2.0] + lorenz2.z => 0.0] p = [lorenz1.σ => 10.0, lorenz1.ρ => 28.0, @@ -68,7 +67,7 @@ gs_ts, = Zygote.gradient(sol) do sol sum(sum.(sol[[lorenz1.x, lorenz2.x], :])) end -@test all(map(x -> x == true_grad_vecsym, gs_ts)) +@test all(map(x -> x == true_grad_vecsym, gs_ts.u)) # BatchedInterface AD @variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0 diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 0081bebd52..60b77e52d4 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -124,11 +124,7 @@ timeseries_systems = [osys, ssys, jsys] set! = setsym(indp, sym) @inferred get(valp) @test get(valp) == val - if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray} - @test_broken valp[sym] - else - @test valp[sym] == val - end + @test valp[sym] == val if !(valp isa SciMLBase.AbstractNoTimeSolution) @inferred set!(valp, newval) @@ -872,7 +868,7 @@ end ud2interp = ud2val[2:4] c1 = SciMLBase.Clock(0.1) - c2 = SciMLBase.SolverStepClock + c2 = SciMLBase.SolverStepClock() for (sym, t, val) in [ (x, c1[2], xinterp[1]), (x, c1[2:4], xinterp), diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 34beb3cf39..80703ea342 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -6,6 +6,7 @@ using Optimization using OptimizationOptimJL using ForwardDiff using SciMLStructures +using Test probs = [] syss = [] @@ -67,30 +68,21 @@ push!(probs, OptimizationProblem(optsys, u0, p)) k = ShiftIndex(t) @mtkbuild discsys = DiscreteSystem( [x ~ x(k - 1) * ρ + y(k - 2), y ~ y(k - 1) * σ - z(k - 2), z ~ z(k - 1) * β + x(k - 2)], - t) + t; defaults = [x => 1.0, y => 1.0, z => 1.0]) # Roundabout method to avoid having to specify values for previous timestep -fn = DiscreteFunction(discsys) -ps = ModelingToolkit.MTKParameters(discsys, p) -discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0]) +discprob = DiscreteProblem(discsys, [], (0, 10), p) +for (var, v) in u0 + discprob[var] = v + discprob[var(k-1)] = 0.0 +end push!(syss, discsys) -push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps)) - -# TODO: Rewrite this example when the MTK codegen is merged -@named sys1 = NonlinearSystem( - [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) -sys1 = complete(sys1) -@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) -sys2 = complete(sys2) -@named fullsys = NonlinearSystem( +push!(probs, discprob) + +@mtkbuild sys = NonlinearSystem( [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], [x, y, z], [σ, β, ρ]) -fullsys = complete(fullsys) - -prob1 = NonlinearProblem(sys1, u0, p) -prob2 = NonlinearProblem(sys2, u0, prob1.p) -sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) -push!(syss, fullsys) +sccprob = SCCNonlinearProblem(sys, u0, p) +push!(syss, sys) push!(probs, sccprob) for (sys, prob) in zip(syss, probs) @@ -273,7 +265,9 @@ end function SciMLBase.detect_cycles( ::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars) for sym in vars - if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) != + newval = ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10) + vs = ModelingToolkit.vars(newval) + if !isempty(vars) && any(in(Set(vars)), vs) NotSymbolic() return true end @@ -296,15 +290,9 @@ end end @testset "SCCNonlinearProblem" begin - @named sys1 = NonlinearSystem( - [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) - sys1 = complete(sys1) - @named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) - sys2 = complete(sys2) - @named fullsys = NonlinearSystem( + @mtkbuild fullsys = NonlinearSystem( [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], [x, y, z], [σ, β, ρ]) - fullsys = complete(fullsys) u0 = [x => 1.0, y => 0.0, @@ -314,15 +302,17 @@ end ρ => 10.0, β => 8 / 3] - prob1 = NonlinearProblem(sys1, u0, p) - prob2 = NonlinearProblem(sys2, u0, prob1.p) - sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) + sccprob = SCCNonlinearProblem(fullsys, u0, p) sccprob2 = remake(sccprob; u0 = 2ones(3)) @test state_values(sccprob2) ≈ 2ones(3) - @test sccprob2.probs[1].u0 ≈ 2ones(2) - @test sccprob2.probs[2].u0 ≈ 2ones(1) + prob1, prob2 = if length(state_values(sccprob2.probs[1])) == 1 + sccprob2.probs[2], sccprob2.probs[1] + else + sccprob2.probs[1], sccprob2.probs[2] + end + @test prob1.u0 ≈ 2ones(2) + @test prob2.u0 ≈ 2ones(1) @test sccprob2.explicitfuns! !== missing @test sccprob2.f.sys !== missing @@ -333,9 +323,9 @@ end @test_throws ["parameters_alias", "SCCNonlinearProblem"] remake( sccprob; parameters_alias = false, p = [σ => 2.0]) - newp = remake_buffer(sys1, prob1.p, [σ], [3.0]) + newp = remake_buffer(sccprob.f.sys, sccprob.p, [σ], [3.0]) sccprob4 = remake(sccprob; parameters_alias = false, p = newp, - probs = [remake(prob1; p = [σ => 3.0]), prob2]) + probs = [remake(sccprob.probs[1]; p = [σ => 3.0]), sccprob.probs[2]]) @test !sccprob4.parameters_alias @test sccprob4.p !== sccprob4.probs[1].p @test sccprob4.p !== sccprob4.probs[2].p diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 91ceae31bc..573f09bc3e 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -294,8 +294,6 @@ prob = SteadyStateProblem(osys, u0, ps) getsym(prob, (:X, :X2))(prob) == (0.1, 0.2) @testset "SCCNonlinearProblem" begin - # TODO: Rewrite this example when the MTK codegen is merged - function fullf!(du, u, p) du[1] = cos(u[2]) - u[1] du[2] = sin(u[1] + u[2]) + u[2] @@ -311,63 +309,10 @@ prob = SteadyStateProblem(osys, u0, ps) @parameters p = 1.0 eqs = Any[0 for _ in 1:8] fullf!(eqs, u, [p]) - @named model = NonlinearSystem(0 .~ eqs, [u...], [p]) - model = complete(model; split = false) - - cache = zeros(4) - cache[1] = 1.0 - - function f1!(du, u, p) - du[1] = cos(u[2]) - u[1] - du[2] = sin(u[1] + u[2]) + u[2] - end - explicitfun1(cache, sols) = nothing - - f1!(eqs, u2[1:2], [p]) - @named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p]) - subsys1 = complete(subsys1; split = false) - prob1 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1), - zeros(2), copy(cache)) - - function f2!(du, u, p) - du[1] = 2u[2] + u[1] + p[1] - du[2] = u[3]^2 + u[2] - du[3] = u[1]^2 + u[3] - end - explicitfun2(cache, sols) = nothing - - f2!(eqs, u2[3:5], [p]) - @named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p]) - subsys2 = complete(subsys2; split = false) - prob2 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2), - zeros(3), copy(cache)) - - function f3!(du, u, p) - du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3] - du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3] - du[3] = p[4] + +u[1] - u[2] - u[3] - end - function explicitfun3(cache, sols) - cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3] - cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3] - cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] + - 6.0sols[2][3] - end - - @parameters tmpvar[1:3] - f3!(eqs, u2[6:8], [p, tmpvar...]) - @named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...]) - subsys3 = complete(subsys3; split = false) - prob3 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3), - zeros(3), copy(cache)) + @mtkbuild model = NonlinearSystem(0 .~ eqs, [u...], [p]) prob = NonlinearProblem(model, []) - sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], - SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), - copy(cache); sys = model) + sccprob = SCCNonlinearProblem(model, []) for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] @test prob[sym] ≈ sccprob[sym] @@ -380,12 +325,10 @@ prob = SteadyStateProblem(osys, u0, ps) for (i, sym) in enumerate([u[1], u[3], u[6]]) sccprob[sym] = 0.5i @test sccprob[sym] ≈ 0.5i - @test sccprob.probs[i].u0[1] ≈ 0.5i end sccprob.ps[p] = 2.5 @test sccprob.ps[p] ≈ 2.5 - @test sccprob.p[1] ≈ 2.5 for scc in sccprob.probs - @test parameter_values(scc)[1] ≈ 2.5 + @test scc.ps[p] ≈ 2.5 end end diff --git a/test/downstream/remake_autodiff.jl b/test/downstream/remake_autodiff.jl index ce8abce588..35015f41cb 100644 --- a/test/downstream/remake_autodiff.jl +++ b/test/downstream/remake_autodiff.jl @@ -1,4 +1,5 @@ using OrdinaryDiffEq, ModelingToolkit, Zygote, SciMLSensitivity +using SymbolicIndexingInterface using ModelingToolkit: t_nounits as t, D_nounits as D @variables x(t) o(t) @@ -17,8 +18,8 @@ end lotka_volterra_sys = structural_simplify(lotka_volterra_sys, split = false) prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), []) sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6) -u0 = [1.0, 1.0] -p = [1.5, 1.0, 1.0, 1.0] +setter = setsym_oop(prob, [unknowns(lotka_volterra_sys); parameters(lotka_volterra_sys)]) +u0, p = setter(prob, [1.0, 1.0, 1.5, 1.0, 1.0, 1.0]) function sum_of_solution(u0, p) _prob = remake(prob, u0 = u0, p = p) diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 6d363967d1..b59424dd2d 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -182,8 +182,7 @@ end xidx = variable_index(sys, x) prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5]) - @test SciMLBase.SavedSubsystem(sys, prob.p, []) === - SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing + @test SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing @test SciMLBase.SavedSubsystem(sys, prob.p, [x, y]) === nothing @test begin ss1 = SciMLBase.SavedSubsystem(sys, prob.p, [x]) @@ -319,12 +318,12 @@ end prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing) @test _idxs === _ss === nothing - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1) @test _idxs == 1 @test _ss isa SciMLBase.SavedSubsystem - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1]) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1]) @test _idxs == [1] @test _ss isa SciMLBase.SavedSubsystem _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x)