Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
bbeb0ce
fix: fix `nothing` pullback for parameter object in Zygote adjoint
AayushSabharwal Feb 26, 2025
7a54aab
test: fix remake autodiff test
AayushSabharwal Feb 26, 2025
ce6ca68
fix: improve inference for non-symbolic `save_idxs`
AayushSabharwal Feb 26, 2025
66ac0fd
fix: fix `constructorof` for `SDEProblem` handling `f::AbstractSDEFunction`
AayushSabharwal Feb 26, 2025
a1aec24
ci: add SciMLSensitivity Core6 and Core7 test groups to downstream CI
AayushSabharwal Feb 26, 2025
aa06f6d
fix: replace Zygote adjoint with ChainRulesCore adjoint
AayushSabharwal Feb 26, 2025
e60a118
test: bump MTK compat in downstream CI
AayushSabharwal Feb 26, 2025
e78e0d5
test: fix printing of cyclic dependencies in test
AayushSabharwal Feb 27, 2025
3845403
test: fix MTK remake tests, use SCCNonlinearProblem codegen
AayushSabharwal Feb 27, 2025
172a6bc
fix: fix definition of `T` and `N` in ChainRulesCore adjoint
AayushSabharwal Feb 27, 2025
2236bd0
test: `JumpProblem` indexing test is no longer broken
AayushSabharwal Feb 28, 2025
29bbc34
test: fix usage of `SolverStepClock` in test
AayushSabharwal Feb 28, 2025
795047e
test: fix adjoints test
AayushSabharwal Feb 28, 2025
73cb4d0
test: fix `SCCNonlinearProblem` indexing test
AayushSabharwal Feb 28, 2025
fb2bc23
fix: improve inference of `get_save_idxs_and_saved_subsystem`
AayushSabharwal Mar 3, 2025
4853750
fix: improve `remake` performance via `late_binding_update_u0_p` improvements
AayushSabharwal Mar 4, 2025
9ee2af8
fix: define `Zygote.@adjoint` to fall back to `ChainRulesCore.rrule`
AayushSabharwal Mar 6, 2025
2ffb0ee
fix: fix cases where `SavedSubsystem` should be `nothing`
AayushSabharwal Mar 6, 2025
d53a127
fix: fix `SavedSubsystem` constructor dispatch
AayushSabharwal Mar 6, 2025
a52552c
fix: handle `rrule` returning extra value in pullback
AayushSabharwal Mar 6, 2025
23a14e1
fix: use `Zygote.ZygoteRuleConfig` when forwarding adjoint
AayushSabharwal Mar 6, 2025
5e2e25e
fix: fix detection of absent index provider in `SavedSubsystem` constructor
AayushSabharwal Mar 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core7}
- {user: SciML, repo: Catalyst.jl, group: All}

steps:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions"
SciMLBasePyCallExt = "PyCall"
SciMLBasePythonCallExt = "PythonCall"
SciMLBaseRCallExt = "RCall"
SciMLBaseZygoteExt = "Zygote"
SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
ADTypes = "0.2.5,1.0.0"
Expand Down
55 changes: 16 additions & 39 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module SciMLBaseChainRulesCoreExt

using SciMLBase
using SciMLBase: getobserved
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
using SymbolicIndexingInterface

function ChainRulesCore.rrule(
Expand All @@ -15,52 +16,28 @@ function ChainRulesCore.rrule(
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
du, dprob = if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)]
Comment on lines 21 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅 oh no.

dp = grz[4] # pullback for p
if dp == NoTangent()
dp = zero_tangent(parameter_values(VA.prob))
end
dprob = remake(VA.prob, p = dp)
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing,
nothing, nothing, nothing, dprob, nothing, nothing,
VA.dense, 0, nothing, nothing, VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
du, dprob
else
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dp = zero_tangent(VA.prob.p)
dprob = remake(VA.prob, p = dp)
Δ′ = ODESolution{
T,
N,
typeof(du),
Nothing,
Nothing,
typeof(VA.t),
typeof(VA.k),
typeof(dprob),
typeof(VA.alg),
typeof(VA.interp),
typeof(VA.alg_choice),
typeof(VA.stats)
}(du,
nothing,
nothing,
VA.t,
VA.k,
dprob,
VA.alg,
VA.interp,
VA.dense,
0,
VA.stats,
VA.alg_choice,
VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
du, dprob
end
T = eltype(eltype(du))
N = ndims(eltype(du)) + 1
Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob,
VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
end
VA[sym, j], ODESolution_getindex_pullback
end
Expand Down
29 changes: 4 additions & 25 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module SciMLBaseZygoteExt
using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
import ChainRulesCore
using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
Expand Down Expand Up @@ -40,31 +41,9 @@ import SciMLStructures
VA[i, j], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
dprob = remake(VA.prob, p = dp)
du, dprob
else
Comment on lines -43 to -53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because the ChainRulesCore one is fixed? If so, just delete this code instead of commenting it out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, yeah.

du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dprob = remake(VA.prob, p = dp)
du, dprob
end
T = eltype(eltype(VA.u))
N = ndims(VA)
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
VA[sym, j], ODESolution_getindex_pullback
@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer)
res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j)
return res, Base.tail ∘ pullback
end

@adjoint function EnsembleSolution(sim, time, converged, stats)
Expand Down
6 changes: 5 additions & 1 deletion src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,14 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
if f isa AbstractSDEFunction
iip = isinplace(f)
if g !== f.g
f = remake(f; g)
end
return SDEProblem{iip}(f, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
else
iip = isinplace(f, 4)
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
end

Expand Down
8 changes: 2 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ end
function Base.showerror(io::IO, err::CyclicDependencyError)
println(io, "Detected cyclic dependency in initial values:")
for (k, v) in err.varmap
println(io, k, " => ", "v")
println(io, k, " => ", v)
end
println(io, "While trying to solve for variables: ", err.vars)
end
Expand Down Expand Up @@ -1085,10 +1085,6 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R
the updated `newu0` and `newp`.
"""
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) &&
(sc = symbolic_container(root_indp)) !== root_indp
return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp)
end
return newu0, newp
end

Expand All @@ -1099,7 +1095,7 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after
`root_indp`.
"""
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
root_indp = prob
root_indp = get_root_indp(prob)
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
end

Expand Down
4 changes: 3 additions & 1 deletion src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4818,8 +4818,10 @@ for S in [:ODEFunction
end
end

const EMPTY_SYMBOLCACHE = SymbolCache()

function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction)
has_sys(fn) ? fn.sys : SymbolCache()
has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE
end

function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym)
Expand Down
72 changes: 43 additions & 29 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
end

function is_empty_indp(indp)
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
isempty(independent_variable_symbols(indp))
function get_root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp
return get_root_indp(sc)
end
return indp
end

# Everything from this point on is public API
Expand Down Expand Up @@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C}
partition_count::C
end

function SavedSubsystem(indp, pobj, saved_idxs)
# nothing saved
if saved_idxs === nothing || isempty(saved_idxs)
SavedSubsystem(indp, pobj, ::Nothing) = nothing

function SavedSubsystem(indp, pobj, idx::Int)
_indp = get_root_indp(indp)
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
return nothing
end
state_map = Dict(1 => idx)
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
end

# this is required because problems with no system have an empty `SymbolCache`
# as their symbolic container.
if is_empty_indp(indp)
function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple})
_indp = get_root_indp(indp)
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
return nothing
end
if eltype(saved_idxs) == Int
state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs))
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
end

# array state symbolics must be scalarized
saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym
Expand Down Expand Up @@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
one is not required. `save_idxs` may be a scalar or `nothing`.
"""
get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing
function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int})
save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs)
end
function get_save_idxs_and_saved_subsystem(prob, save_idx::Int)
save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx)
end
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if save_idxs === nothing
saved_subsystem = nothing
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this be an unnecessary allocation in the scalar case? You only want to do this when the scalar is symbolic.

_save_idxs = (save_idxs,)
else
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = [save_idxs]
_save_idxs = save_idxs
end
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
Comment on lines +387 to +389
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just error early? This case is a bit odd. Saving nothing? Must be a user issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user could still be saving discrete variables even when there are no states to save, so this case is not necessarily a user error.

elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
_save_idxs = save_idxs
end
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
save_idxs = _save_idxs
end
save_idxs = _save_idxs
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ DelayDiffEq = "5"
DiffEqCallbacks = "3, 4"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "9.64.1"
ModelingToolkit = "9.64.3"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3, 4"
Optimization = "4"
Expand Down
5 changes: 2 additions & 3 deletions test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ u0 = [lorenz1.x => 1.0,
lorenz1.z => 0.0,
lorenz2.x => 0.0,
lorenz2.y => 1.0,
lorenz2.z => 0.0,
a => 2.0]
lorenz2.z => 0.0]

p = [lorenz1.σ => 10.0,
lorenz1.ρ => 28.0,
Expand Down Expand Up @@ -68,7 +67,7 @@ gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))
@test all(map(x -> x == true_grad_vecsym, gs_ts.u))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
Expand Down
8 changes: 2 additions & 6 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,7 @@ timeseries_systems = [osys, ssys, jsys]
set! = setsym(indp, sym)
@inferred get(valp)
@test get(valp) == val
if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray}
@test_broken valp[sym]
else
@test valp[sym] == val
end
@test valp[sym] == val

if !(valp isa SciMLBase.AbstractNoTimeSolution)
@inferred set!(valp, newval)
Expand Down Expand Up @@ -872,7 +868,7 @@ end
ud2interp = ud2val[2:4]

c1 = SciMLBase.Clock(0.1)
c2 = SciMLBase.SolverStepClock
c2 = SciMLBase.SolverStepClock()
for (sym, t, val) in [
(x, c1[2], xinterp[1]),
(x, c1[2:4], xinterp),
Expand Down
Loading
Loading