-
-
Notifications
You must be signed in to change notification settings - Fork 113
fix: fix remake autodiff tests and Zygote adjoint #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bbeb0ce
7a54aab
ce6ca68
66ac0fd
a1aec24
aa06f6d
e60a118
e78e0d5
3845403
172a6bc
2236bd0
29bbc34
795047e
73cb4d0
fb2bc23
4853750
9ee2af8
2ffb0ee
d53a127
a52552c
23a14e1
5e2e25e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ module SciMLBaseZygoteExt | |
| using Zygote | ||
| using Zygote: @adjoint, pullback | ||
| import Zygote: literal_getproperty | ||
| import ChainRulesCore | ||
| using SciMLBase | ||
| using SciMLBase: ODESolution, remake, | ||
| getobserved, build_solution, EnsembleSolution, | ||
|
|
@@ -40,31 +41,9 @@ import SciMLStructures | |
| VA[i, j], ODESolution_getindex_pullback | ||
| end | ||
|
|
||
| @adjoint function Base.getindex(VA::ODESolution, sym, j::Int) | ||
| function ODESolution_getindex_pullback(Δ) | ||
| i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym | ||
| du, dprob = if i === nothing | ||
| getter = getobserved(VA) | ||
| grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) | ||
| du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] | ||
| dp = grz[3] # pullback for p | ||
| dprob = remake(VA.prob, p = dp) | ||
| du, dprob | ||
| else | ||
|
Comment on lines
-43
to
-53
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this because the ChainRulesCore one is fixed? If so, just delete instead of comment.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, yeah. |
||
| du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : | ||
| zero(VA.u[1]) for m in 1:length(VA.u)] | ||
| dp = zero(VA.prob.p) | ||
| dprob = remake(VA.prob, p = dp) | ||
| du, dprob | ||
| end | ||
| T = eltype(eltype(VA.u)) | ||
| N = ndims(VA) | ||
| Δ′ = ODESolution{T, N}(du, nothing, nothing, | ||
| VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, | ||
| VA.alg_choice, VA.retcode) | ||
| (Δ′, nothing, nothing) | ||
| end | ||
| VA[sym, j], ODESolution_getindex_pullback | ||
| @adjoint function Base.getindex(VA::ODESolution, sym, j::Integer) | ||
| res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j) | ||
| return res, Base.tail ∘ pullback | ||
| end | ||
|
|
||
| @adjoint function EnsembleSolution(sim, time, converged, stats) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t) | |
| return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) | ||
| end | ||
|
|
||
| function is_empty_indp(indp) | ||
| isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) && | ||
| isempty(independent_variable_symbols(indp)) | ||
| function get_root_indp(indp) | ||
| if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp | ||
| return get_root_indp(sc) | ||
| end | ||
| return indp | ||
| end | ||
|
|
||
| # Everything from this point on is public API | ||
|
|
@@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C} | |
| partition_count::C | ||
| end | ||
|
|
||
| function SavedSubsystem(indp, pobj, saved_idxs) | ||
| # nothing saved | ||
| if saved_idxs === nothing || isempty(saved_idxs) | ||
| SavedSubsystem(indp, pobj, ::Nothing) = nothing | ||
|
|
||
| function SavedSubsystem(indp, pobj, idx::Int) | ||
| _indp = get_root_indp(indp) | ||
| if _indp === EMPTY_SYMBOLCACHE || _indp === nothing | ||
| return nothing | ||
| end | ||
| state_map = Dict(1 => idx) | ||
| return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) | ||
| end | ||
|
|
||
| # this is required because problems with no system have an empty `SymbolCache` | ||
| # as their symbolic container. | ||
| if is_empty_indp(indp) | ||
| function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple}) | ||
| _indp = get_root_indp(indp) | ||
| if _indp === EMPTY_SYMBOLCACHE || _indp === nothing | ||
| return nothing | ||
| end | ||
| if eltype(saved_idxs) == Int | ||
| state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs)) | ||
| return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing) | ||
| end | ||
|
|
||
| # array state symbolics must be scalarized | ||
| saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym | ||
|
|
@@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so | |
| The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case | ||
| one is not required. `save_idxs` may be a scalar or `nothing`. | ||
| """ | ||
| get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing | ||
| function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int}) | ||
| save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs) | ||
| end | ||
| function get_save_idxs_and_saved_subsystem(prob, save_idx::Int) | ||
| save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx) | ||
| end | ||
| function get_save_idxs_and_saved_subsystem(prob, save_idxs) | ||
| if save_idxs === nothing | ||
| saved_subsystem = nothing | ||
| if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this be an unnecessary allocation in the scalar case? You only want to do this when symbolic scalar. |
||
| _save_idxs = (save_idxs,) | ||
| else | ||
| if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() | ||
| _save_idxs = [save_idxs] | ||
| _save_idxs = save_idxs | ||
| end | ||
| saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) | ||
| if saved_subsystem !== nothing | ||
| _save_idxs = get_saved_state_idxs(saved_subsystem) | ||
| if isempty(_save_idxs) | ||
| # no states to save | ||
| save_idxs = Int[] | ||
|
Comment on lines
+387
to
+389
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just error early? This case is a bit odd. Saving nothing? Must be a user issue.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The user could still be saving discrete variables |
||
| elseif !(save_idxs isa AbstractArray) || | ||
| symbolic_type(save_idxs) != NotSymbolic() | ||
| # only a single state to save, and save it as a scalar timeseries instead of | ||
| # single-element array | ||
| save_idxs = only(_save_idxs) | ||
| else | ||
| _save_idxs = save_idxs | ||
| end | ||
| saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) | ||
| if saved_subsystem !== nothing | ||
| _save_idxs = get_saved_state_idxs(saved_subsystem) | ||
| if isempty(_save_idxs) | ||
| # no states to save | ||
| save_idxs = Int[] | ||
| elseif !(save_idxs isa AbstractArray) || | ||
| symbolic_type(save_idxs) != NotSymbolic() | ||
| # only a single state to save, and save it as a scalar timeseries instead of | ||
| # single-element array | ||
| save_idxs = only(_save_idxs) | ||
| else | ||
| save_idxs = _save_idxs | ||
| end | ||
| save_idxs = _save_idxs | ||
| end | ||
| end | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😅 oh no.