From bbeb0cedaf692a4fb5cc9ab80cdc78e0ecdd73cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 13:44:05 +0530 Subject: [PATCH 01/22] fix: fix `nothing` pullback for parameter object in Zygote adjoint --- ext/SciMLBaseZygoteExt.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index fa67f06464..9aed585b74 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -48,6 +48,9 @@ end grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] dp = grz[3] # pullback for p + if dp === nothing + dp = parameter_values(VA) + end dprob = remake(VA.prob, p = dp) du, dprob else From 7a54aab2436efddd74af2f22ae9df5948f745bef Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 13:44:11 +0530 Subject: [PATCH 02/22] test: fix remake autodiff test --- test/downstream/remake_autodiff.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/downstream/remake_autodiff.jl b/test/downstream/remake_autodiff.jl index ce8abce588..35015f41cb 100644 --- a/test/downstream/remake_autodiff.jl +++ b/test/downstream/remake_autodiff.jl @@ -1,4 +1,5 @@ using OrdinaryDiffEq, ModelingToolkit, Zygote, SciMLSensitivity +using SymbolicIndexingInterface using ModelingToolkit: t_nounits as t, D_nounits as D @variables x(t) o(t) @@ -17,8 +18,8 @@ end lotka_volterra_sys = structural_simplify(lotka_volterra_sys, split = false) prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), []) sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6) -u0 = [1.0, 1.0] -p = [1.5, 1.0, 1.0, 1.0] +setter = setsym_oop(prob, [unknowns(lotka_volterra_sys); parameters(lotka_volterra_sys)]) +u0, p = setter(prob, [1.0, 1.0, 1.5, 1.0, 1.0, 1.0]) function sum_of_solution(u0, p) _prob = remake(prob, u0 = u0, p = p) From ce6ca68141d7e7ecf3b70042b59ca46df338265c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 14:01:00 +0530 Subject: [PATCH 03/22] fix: improve inference for non-symbolic `save_idxs` --- src/solutions/save_idxs.jl | 57 +++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index 3d86f306a2..d58095c2f4 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -105,6 +105,20 @@ struct SavedSubsystem{V, T, M, I, P, Q, C} partition_count::C end +SavedSubsystem(indp, pobj, ::Nothing) = nothing + +function SavedSubsystem(indp, pobj, saved_idxs::Vector{Int}) + isempty(saved_idxs) && return nothing + isempty(variable_symbols(indp)) && return nothing + state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs)) + return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) +end + +function SavedSubsystem(indp, pobj, idx::Int) + state_map = Dict(1 => idx) + return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) +end + function SavedSubsystem(indp, pobj, saved_idxs) # nothing saved if saved_idxs === nothing || isempty(saved_idxs) @@ -357,29 +371,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case one is not required. `save_idxs` may be a scalar or `nothing`. """ +get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing +function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int}) + save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs) +end +function get_save_idxs_and_saved_subsystem(prob, save_idx::Int) + save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx) +end function get_save_idxs_and_saved_subsystem(prob, save_idxs) - if save_idxs === nothing - saved_subsystem = nothing + if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + _save_idxs = [save_idxs] else - if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() - _save_idxs = [save_idxs] + _save_idxs = save_idxs + end + saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) + if saved_subsystem !== nothing + _save_idxs = get_saved_state_idxs(saved_subsystem) + if isempty(_save_idxs) + # no states to save + save_idxs = Int[] + elseif !(save_idxs isa AbstractArray) || + symbolic_type(save_idxs) != NotSymbolic() + # only a single state to save, and save it as a scalar timeseries instead of + # single-element array + save_idxs = only(_save_idxs) else - _save_idxs = save_idxs - end - saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) - if saved_subsystem !== nothing - _save_idxs = get_saved_state_idxs(saved_subsystem) - if isempty(_save_idxs) - # no states to save - save_idxs = Int[] - elseif !(save_idxs isa AbstractArray) || - symbolic_type(save_idxs) != NotSymbolic() - # only a single state to save, and save it as a scalar timeseries instead of - # single-element array - save_idxs = only(_save_idxs) - else - save_idxs = _save_idxs - end + save_idxs = _save_idxs end end From 66ac0fd58ad8ae4515914c2f20459aeea411fc85 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 14:14:53 +0530 Subject: [PATCH 04/22] fix: fix `constructorof` for `SDEProblem` handling `f::AbstractSDEFunction` --- src/problems/sde_problems.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index 5cc10db511..68311d22f1 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -129,10 +129,14 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem} function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed) if f isa AbstractSDEFunction iip = isinplace(f) + if g !== f.g + f = remake(f; g) + end + return SDEProblem{iip}(f, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) else iip = isinplace(f, 4) + return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) end - return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) end end From a1aec2411d98310be106ac489ac956218e1dd22b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 14:31:10 +0530 Subject: [PATCH 05/22] ci: add SciMLSensitivity Core6 and Core7 test groups to downstream CI --- .github/workflows/Downstream.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index c17dfa0bce..9558c6ba08 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -49,6 +49,8 @@ jobs: - {user: SciML, repo: SciMLSensitivity.jl, group: Core3} - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core6} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core7} - {user: SciML, repo: Catalyst.jl, group: All} steps: From aa06f6d30bbf8d6154eeb7aba46b320d4ef39955 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Feb 2025 21:03:49 +0530 Subject: [PATCH 06/22] fix: replace Zygote adjoint with ChainRulesCore adjoint --- ext/SciMLBaseChainRulesCoreExt.jl | 51 +++++++++---------------------- ext/SciMLBaseZygoteExt.jl | 30 ------------------ 2 files changed, 14 insertions(+), 67 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 934a0a1dfa..6aea80a0ab 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,8 +1,9 @@ module SciMLBaseChainRulesCoreExt using SciMLBase +using SciMLBase: getobserved import ChainRulesCore -import ChainRulesCore: NoTangent, @non_differentiable +import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad using SymbolicIndexingInterface function ChainRulesCore.rrule( @@ -15,52 +16,28 @@ function ChainRulesCore.rrule( j::Integer) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - if i === nothing + du, dprob = if i === nothing getter = getobserved(VA) grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) - du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] - dp = grz[3] # pullback for p + du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)] + dp = grz[4] # pullback for p + if dp == NoTangent() + dp = zero_tangent(parameter_values(VA.prob)) + end dprob = remake(VA.prob, p = dp) T = eltype(eltype(VA.u)) N = length(VA.prob.p) - Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing, - typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing, - nothing, nothing, nothing, dprob, nothing, nothing, - VA.dense, 0, nothing, nothing, VA.retcode) - (NoTangent(), Δ′, NoTangent(), NoTangent()) + du, dprob else du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)] - dp = zero(VA.prob.p) + dp = zero_tangent(VA.prob.p) dprob = remake(VA.prob, p = dp) - Δ′ = ODESolution{ - T, - N, - typeof(du), - Nothing, - Nothing, - typeof(VA.t), - typeof(VA.k), - typeof(dprob), - typeof(VA.alg), - typeof(VA.interp), - typeof(VA.alg_choice), - typeof(VA.stats) - }(du, - nothing, - nothing, - VA.t, - VA.k, - dprob, - VA.alg, - VA.interp, - VA.dense, - 0, - VA.stats, - VA.alg_choice, - VA.retcode) - (NoTangent(), Δ′, NoTangent(), NoTangent()) + du, dprob end + Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob, + VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) + (NoTangent(), Δ′, NoTangent(), NoTangent()) end VA[sym, j], ODESolution_getindex_pullback end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 9aed585b74..7ffc79709c 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -40,36 +40,6 @@ import SciMLStructures VA[i, j], ODESolution_getindex_pullback end -@adjoint function Base.getindex(VA::ODESolution, sym, j::Int) - function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - du, dprob = if i === nothing - getter = getobserved(VA) - grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) - du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] - dp = grz[3] # pullback for p - if dp === nothing - dp = parameter_values(VA) - end - dprob = remake(VA.prob, p = dp) - du, dprob - else - du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : - zero(VA.u[1]) for m in 1:length(VA.u)] - dp = zero(VA.prob.p) - dprob = remake(VA.prob, p = dp) - du, dprob - end - T = eltype(eltype(VA.u)) - N = ndims(VA) - Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, - VA.alg_choice, VA.retcode) - (Δ′, nothing, nothing) - end - VA[sym, j], ODESolution_getindex_pullback -end - @adjoint function EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged, stats) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} From e60a118d5e4e72412f8163f1a5e5c9a9eb7614a8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 27 Feb 2025 00:33:42 +0530 Subject: [PATCH 07/22] test: bump MTK compat in downstream CI --- test/downstream/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index c390a258f4..5e65ebbf4f 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -34,7 +34,7 @@ DelayDiffEq = "5" DiffEqCallbacks = "3, 4" ForwardDiff = "0.10" JumpProcesses = "9.10" -ModelingToolkit = "9.64.1" +ModelingToolkit = "9.64.3" ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3, 4" Optimization = "4" From e78e0d54ee871a79bb7d9afe44e3554e6c888fb5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 27 Feb 2025 15:44:46 +0530 Subject: [PATCH 08/22] test: fix printing of cyclic dependencies in test --- src/remake.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/remake.jl b/src/remake.jl index a5417a27eb..2d5a59c27e 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -958,7 +958,7 @@ end function Base.showerror(io::IO, err::CyclicDependencyError) println(io, "Detected cyclic dependency in initial values:") for (k, v) in err.varmap - println(io, k, " => ", "v") + println(io, k, " => ", v) end println(io, "While trying to solve for variables: ", err.vars) end From 3845403162c56be1c14f2f65b7cd0caf44bcf4de Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 27 Feb 2025 15:45:29 +0530 Subject: [PATCH 09/22] test: fix MTK remake tests, use SCCNonlinearProblem codegen --- test/downstream/modelingtoolkit_remake.jl | 62 ++++++++++------------- 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 34beb3cf39..80703ea342 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -6,6 +6,7 @@ using Optimization using OptimizationOptimJL using ForwardDiff using SciMLStructures +using Test probs = [] syss = [] @@ -67,30 +68,21 @@ push!(probs, OptimizationProblem(optsys, u0, p)) k = ShiftIndex(t) @mtkbuild discsys = DiscreteSystem( [x ~ x(k - 1) * ρ + y(k - 2), y ~ y(k - 1) * σ - z(k - 2), z ~ z(k - 1) * β + x(k - 2)], - t) + t; defaults = [x => 1.0, y => 1.0, z => 1.0]) # Roundabout method to avoid having to specify values for previous timestep -fn = DiscreteFunction(discsys) -ps = ModelingToolkit.MTKParameters(discsys, p) -discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0]) +discprob = DiscreteProblem(discsys, [], (0, 10), p) +for (var, v) in u0 + discprob[var] = v + discprob[var(k-1)] = 0.0 +end push!(syss, discsys) -push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps)) - -# TODO: Rewrite this example when the MTK codegen is merged -@named sys1 = NonlinearSystem( - [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) -sys1 = complete(sys1) -@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) -sys2 = complete(sys2) -@named fullsys = NonlinearSystem( +push!(probs, discprob) + +@mtkbuild sys = NonlinearSystem( [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], [x, y, z], [σ, β, ρ]) -fullsys = complete(fullsys) - -prob1 = NonlinearProblem(sys1, u0, p) -prob2 = NonlinearProblem(sys2, u0, prob1.p) -sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) -push!(syss, fullsys) +sccprob = SCCNonlinearProblem(sys, u0, p) +push!(syss, sys) push!(probs, sccprob) for (sys, prob) in zip(syss, probs) @@ -273,7 +265,9 @@ end function SciMLBase.detect_cycles( ::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars) for sym in vars - if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) != + newval = ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10) + vs = ModelingToolkit.vars(newval) + if !isempty(vars) && any(in(Set(vars)), vs) NotSymbolic() return true end @@ -296,15 +290,9 @@ end end @testset "SCCNonlinearProblem" begin - @named sys1 = NonlinearSystem( - [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) - sys1 = complete(sys1) - @named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) - sys2 = complete(sys2) - @named fullsys = NonlinearSystem( + @mtkbuild fullsys = NonlinearSystem( [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], [x, y, z], [σ, β, ρ]) - fullsys = complete(fullsys) u0 = [x => 1.0, y => 0.0, @@ -314,15 +302,17 @@ end ρ => 10.0, β => 8 / 3] - prob1 = NonlinearProblem(sys1, u0, p) - prob2 = NonlinearProblem(sys2, u0, prob1.p) - sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) + sccprob = SCCNonlinearProblem(fullsys, u0, p) sccprob2 = remake(sccprob; u0 = 2ones(3)) @test state_values(sccprob2) ≈ 2ones(3) - @test sccprob2.probs[1].u0 ≈ 2ones(2) - @test sccprob2.probs[2].u0 ≈ 2ones(1) + prob1, prob2 = if length(state_values(sccprob2.probs[1])) == 1 + sccprob2.probs[2], sccprob2.probs[1] + else + sccprob2.probs[1], sccprob2.probs[2] + end + @test prob1.u0 ≈ 2ones(2) + @test prob2.u0 ≈ 2ones(1) @test sccprob2.explicitfuns! !== missing @test sccprob2.f.sys !== missing @@ -333,9 +323,9 @@ end @test_throws ["parameters_alias", "SCCNonlinearProblem"] remake( sccprob; parameters_alias = false, p = [σ => 2.0]) - newp = remake_buffer(sys1, prob1.p, [σ], [3.0]) + newp = remake_buffer(sccprob.f.sys, sccprob.p, [σ], [3.0]) sccprob4 = remake(sccprob; parameters_alias = false, p = newp, - probs = [remake(prob1; p = [σ => 3.0]), prob2]) + probs = [remake(sccprob.probs[1]; p = [σ => 3.0]), sccprob.probs[2]]) @test !sccprob4.parameters_alias @test sccprob4.p !== sccprob4.probs[1].p @test sccprob4.p !== sccprob4.probs[2].p From 172a6bcc1533030994159aed11b3aaf93aeb19fd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 27 Feb 2025 17:28:33 +0530 Subject: [PATCH 10/22] fix: fix definition of `T` and `N` in ChainRulesCore adjoint --- ext/SciMLBaseChainRulesCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 6aea80a0ab..278bb00350 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -25,8 +25,6 @@ function ChainRulesCore.rrule( dp = zero_tangent(parameter_values(VA.prob)) end dprob = remake(VA.prob, p = dp) - T = eltype(eltype(VA.u)) - N = length(VA.prob.p) du, dprob else du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : @@ -35,6 +33,8 @@ function ChainRulesCore.rrule( dprob = remake(VA.prob, p = dp) du, dprob end + T = eltype(eltype(du)) + N = ndims(eltype(du)) + 1 Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (NoTangent(), Δ′, NoTangent(), NoTangent()) From 2236bd0626c62ac13301123b94024294bcee9721 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 28 Feb 2025 12:09:36 +0530 Subject: [PATCH 11/22] test: `JumpProblem` indexing test is no longer broken --- test/downstream/comprehensive_indexing.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 0081bebd52..3e74bb3b6b 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -124,11 +124,7 @@ timeseries_systems = [osys, ssys, jsys] set! = setsym(indp, sym) @inferred get(valp) @test get(valp) == val - if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray} - @test_broken valp[sym] - else - @test valp[sym] == val - end + @test valp[sym] == val if !(valp isa SciMLBase.AbstractNoTimeSolution) @inferred set!(valp, newval) From 29bbc3478b7cebc1df5498969413360ea7831f24 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 28 Feb 2025 12:10:47 +0530 Subject: [PATCH 12/22] test: fix usage of `SolverStepClock` in test --- test/downstream/comprehensive_indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 3e74bb3b6b..60b77e52d4 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -868,7 +868,7 @@ end ud2interp = ud2val[2:4] c1 = SciMLBase.Clock(0.1) - c2 = SciMLBase.SolverStepClock + c2 = SciMLBase.SolverStepClock() for (sym, t, val) in [ (x, c1[2], xinterp[1]), (x, c1[2:4], xinterp), From 795047ee9b27cd626261abfd8682062553944cc8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 28 Feb 2025 15:04:59 +0530 Subject: [PATCH 13/22] test: fix adjoints test --- test/downstream/adjoints.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/downstream/adjoints.jl b/test/downstream/adjoints.jl index 4e75e19b6f..9269de6b40 100644 --- a/test/downstream/adjoints.jl +++ b/test/downstream/adjoints.jl @@ -22,8 +22,7 @@ u0 = [lorenz1.x => 1.0, lorenz1.z => 0.0, lorenz2.x => 0.0, lorenz2.y => 1.0, - lorenz2.z => 0.0, - a => 2.0] + lorenz2.z => 0.0] p = [lorenz1.σ => 10.0, lorenz1.ρ => 28.0, @@ -68,7 +67,7 @@ gs_ts, = Zygote.gradient(sol) do sol sum(sum.(sol[[lorenz1.x, lorenz2.x], :])) end -@test all(map(x -> x == true_grad_vecsym, gs_ts)) +@test all(map(x -> x == true_grad_vecsym, gs_ts.u)) # BatchedInterface AD @variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0 From 73cb4d0e2f464b8a2e74bdcce798367cc3dafb55 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 28 Feb 2025 15:05:12 +0530 Subject: [PATCH 14/22] test: fix `SCCNonlinearProblem` indexing test --- test/downstream/problem_interface.jl | 63 ++-------------------------- 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 91ceae31bc..573f09bc3e 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -294,8 +294,6 @@ prob = SteadyStateProblem(osys, u0, ps) getsym(prob, (:X, :X2))(prob) == (0.1, 0.2) @testset "SCCNonlinearProblem" begin - # TODO: Rewrite this example when the MTK codegen is merged - function fullf!(du, u, p) du[1] = cos(u[2]) - u[1] du[2] = sin(u[1] + u[2]) + u[2] @@ -311,63 +309,10 @@ prob = SteadyStateProblem(osys, u0, ps) @parameters p = 1.0 eqs = Any[0 for _ in 1:8] fullf!(eqs, u, [p]) - @named model = NonlinearSystem(0 .~ eqs, [u...], [p]) - model = complete(model; split = false) - - cache = zeros(4) - cache[1] = 1.0 - - function f1!(du, u, p) - du[1] = cos(u[2]) - u[1] - du[2] = sin(u[1] + u[2]) + u[2] - end - explicitfun1(cache, sols) = nothing - - f1!(eqs, u2[1:2], [p]) - @named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p]) - subsys1 = complete(subsys1; split = false) - prob1 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1), - zeros(2), copy(cache)) - - function f2!(du, u, p) - du[1] = 2u[2] + u[1] + p[1] - du[2] = u[3]^2 + u[2] - du[3] = u[1]^2 + u[3] - end - explicitfun2(cache, sols) = nothing - - f2!(eqs, u2[3:5], [p]) - @named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p]) - subsys2 = complete(subsys2; split = false) - prob2 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2), - zeros(3), copy(cache)) - - function f3!(du, u, p) - du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3] - du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3] - du[3] = p[4] + +u[1] - u[2] - u[3] - end - function explicitfun3(cache, sols) - cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3] - cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3] - cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] + - 6.0sols[2][3] - end - - @parameters tmpvar[1:3] - f3!(eqs, u2[6:8], [p, tmpvar...]) - @named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...]) - subsys3 = complete(subsys3; split = false) - prob3 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3), - zeros(3), copy(cache)) + @mtkbuild model = NonlinearSystem(0 .~ eqs, [u...], [p]) prob = NonlinearProblem(model, []) - sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], - SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), - copy(cache); sys = model) + sccprob = SCCNonlinearProblem(model, []) for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] @test prob[sym] ≈ sccprob[sym] @@ -380,12 +325,10 @@ prob = SteadyStateProblem(osys, u0, ps) for (i, sym) in enumerate([u[1], u[3], u[6]]) sccprob[sym] = 0.5i @test sccprob[sym] ≈ 0.5i - @test sccprob.probs[i].u0[1] ≈ 0.5i end sccprob.ps[p] = 2.5 @test sccprob.ps[p] ≈ 2.5 - @test sccprob.p[1] ≈ 2.5 for scc in sccprob.probs - @test parameter_values(scc)[1] ≈ 2.5 + @test scc.ps[p] ≈ 2.5 end end From fb2bc238d55bfced076f6a18db7269d81f381b11 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Mar 2025 12:20:49 +0530 Subject: [PATCH 15/22] fix: improve inference of `get_save_idxs_and_saved_subsystem` --- src/scimlfunctions.jl | 4 +++- src/solutions/save_idxs.jl | 31 +++++++++++---------------- test/downstream/solution_interface.jl | 9 ++++---- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index a9b2622943..139afd3b00 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4818,8 +4818,10 @@ for S in [:ODEFunction end end +const EMPTY_SYMBOLCACHE = SymbolCache() + function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction) - has_sys(fn) ? fn.sys : SymbolCache() + has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE end function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index d58095c2f4..c74ebd6931 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t) return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) end -function is_empty_indp(indp) - isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) && - isempty(independent_variable_symbols(indp)) +function get_root_indp(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp + return get_root_indp(sc) + end + return indp end # Everything from this point on is public API @@ -107,28 +109,19 @@ end SavedSubsystem(indp, pobj, ::Nothing) = nothing -function SavedSubsystem(indp, pobj, saved_idxs::Vector{Int}) - isempty(saved_idxs) && return nothing - isempty(variable_symbols(indp)) && return nothing - state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs)) - return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) -end - function SavedSubsystem(indp, pobj, idx::Int) state_map = Dict(1 => idx) return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) end -function SavedSubsystem(indp, pobj, saved_idxs) - # nothing saved - if saved_idxs === nothing || isempty(saved_idxs) +function SavedSubsystem(indp, pobj, saved_idxs::Union{Array, Tuple}) + _indp = get_root_indp(indp) + if indp === EMPTY_SYMBOLCACHE || indp === nothing return nothing end - - # this is required because problems with no system have an empty `SymbolCache` - # as their symbolic container. - if is_empty_indp(indp) - return nothing + if eltype(saved_idxs) == Int + state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs)) + return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) end # array state symbolics must be scalarized @@ -380,7 +373,7 @@ function get_save_idxs_and_saved_subsystem(prob, save_idx::Int) end function get_save_idxs_and_saved_subsystem(prob, save_idxs) if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() - _save_idxs = [save_idxs] + _save_idxs = (save_idxs,) else _save_idxs = save_idxs end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 6d363967d1..b59424dd2d 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -182,8 +182,7 @@ end xidx = variable_index(sys, x) prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5]) - @test SciMLBase.SavedSubsystem(sys, prob.p, []) === - SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing + @test SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing @test SciMLBase.SavedSubsystem(sys, prob.p, [x, y]) === nothing @test begin ss1 = SciMLBase.SavedSubsystem(sys, prob.p, [x]) @@ -319,12 +318,12 @@ end prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing) @test _idxs === _ss === nothing - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1) @test _idxs == 1 @test _ss isa SciMLBase.SavedSubsystem - _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1]) + _idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1]) @test _idxs == [1] @test _ss isa SciMLBase.SavedSubsystem _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x) From 485375059815beb0f3a9a696823d1efd80e78723 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Mar 2025 12:30:46 +0530 Subject: [PATCH 16/22] fix: improve `remake` performance via `late_binding_update_u0_p` improvements I have no idea why this works --- src/remake.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 2d5a59c27e..16124b17c9 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1085,10 +1085,6 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R the updated `newu0` and `newp`. """ function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) - if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) && - (sc = symbolic_container(root_indp)) !== root_indp - return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp) - end return newu0, newp end @@ -1099,7 +1095,7 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after `root_indp`. """ function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp) - root_indp = prob + root_indp = get_root_indp(prob) return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp) end From 9ee2af866c54315e671fc3e9ceda55652874ee03 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 12:40:32 +0530 Subject: [PATCH 17/22] fix: define `Zygote.@adjoint` to fall back to `ChainRulesCore.rrule` --- Project.toml | 2 +- ext/SciMLBaseZygoteExt.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6c78004166..d1d1d447b1 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" -SciMLBaseZygoteExt = "Zygote" +SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] ADTypes = "0.2.5,1.0.0" diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 7ffc79709c..0a9edbc8d8 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -3,6 +3,7 @@ module SciMLBaseZygoteExt using Zygote using Zygote: @adjoint, pullback import Zygote: literal_getproperty +import ChainRulesCore using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, @@ -40,6 +41,12 @@ import SciMLStructures VA[i, j], ODESolution_getindex_pullback end +struct ZygoteConfig <: ChainRulesCore.RuleConfig{ChainRulesCore.HasReverseMode} end + +@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer) + Zygote.ChainRulesCore.rrule(ZygoteConfig(), getindex, VA, sym, j) +end + @adjoint function EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged, stats) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} From 2ffb0ee89bfd5eff8e2a3efc11d8a8b8d20d42b0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 12:40:54 +0530 Subject: [PATCH 18/22] fix: fix cases where `SavedSubsystem` should be `nothing` --- src/solutions/save_idxs.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index c74ebd6931..01c4176ea4 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -110,17 +110,21 @@ end SavedSubsystem(indp, pobj, ::Nothing) = nothing function SavedSubsystem(indp, pobj, idx::Int) + _indp = get_root_indp(indp) + if indp === EMPTY_SYMBOLCACHE || indp === nothing + return nothing + end state_map = Dict(1 => idx) return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) end function SavedSubsystem(indp, pobj, saved_idxs::Union{Array, Tuple}) _indp = get_root_indp(indp) - if indp === EMPTY_SYMBOLCACHE || indp === nothing + if _indp === EMPTY_SYMBOLCACHE || _indp === nothing return nothing end if eltype(saved_idxs) == Int - state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs)) + state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs)) return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) end From d53a1275ea17d01e44ac917980ce57048b28bd4b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 12:42:37 +0530 Subject: [PATCH 19/22] fix: fix `SavedSubsystem` constructor dispatch --- src/solutions/save_idxs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index 01c4176ea4..fe7035a8ae 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -118,7 +118,7 @@ function SavedSubsystem(indp, pobj, idx::Int) return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) end -function SavedSubsystem(indp, pobj, saved_idxs::Union{Array, Tuple}) +function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple}) _indp = get_root_indp(indp) if _indp === EMPTY_SYMBOLCACHE || _indp === nothing return nothing From a52552c415a3dc1aa30701e76d014f64622ed8bb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 13:55:02 +0530 Subject: [PATCH 20/22] fix: handle `rrule` returning extra value in pullback --- ext/SciMLBaseZygoteExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 0a9edbc8d8..c458a7f605 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -44,7 +44,8 @@ end struct ZygoteConfig <: ChainRulesCore.RuleConfig{ChainRulesCore.HasReverseMode} end @adjoint function Base.getindex(VA::ODESolution, sym, j::Integer) - Zygote.ChainRulesCore.rrule(ZygoteConfig(), getindex, VA, sym, j) + res, pullback = Zygote.ChainRulesCore.rrule(ZygoteConfig(), getindex, VA, sym, j) + return res, Base.tail ∘ pullback end @adjoint function EnsembleSolution(sim, time, converged, stats) From 23a14e123b98f1e5db21eaee83bb9751103a8d07 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 14:51:41 +0530 Subject: [PATCH 21/22] fix: use `Zygote.ZygoteRuleConfig` when forwarding adjoint --- ext/SciMLBaseZygoteExt.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index c458a7f605..902228b03f 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -41,10 +41,8 @@ import SciMLStructures VA[i, j], ODESolution_getindex_pullback end -struct ZygoteConfig <: ChainRulesCore.RuleConfig{ChainRulesCore.HasReverseMode} end - @adjoint function Base.getindex(VA::ODESolution, sym, j::Integer) - res, pullback = Zygote.ChainRulesCore.rrule(ZygoteConfig(), getindex, VA, sym, j) + res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j) return res, Base.tail ∘ pullback end From 5e2e25e3c0202cf0d532b663d5af72adde9326a1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 6 Mar 2025 16:35:20 +0530 Subject: [PATCH 22/22] fix: fix detection of absent index provider in `SavedSubsystem` constructor --- src/solutions/save_idxs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index fe7035a8ae..2559eaf349 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -111,7 +111,7 @@ SavedSubsystem(indp, pobj, ::Nothing) = nothing function SavedSubsystem(indp, pobj, idx::Int) _indp = get_root_indp(indp) - if indp === EMPTY_SYMBOLCACHE || indp === nothing + if _indp === EMPTY_SYMBOLCACHE || _indp === nothing return nothing end state_map = Dict(1 => idx)