@@ -40,35 +40,36 @@ import SciMLStructures
4040 VA[i, j], ODESolution_getindex_pullback
4141end
4242
43- @adjoint function Base. getindex (VA:: ODESolution , sym, j:: Int )
44- function ODESolution_getindex_pullback (Δ)
45- i = symbolic_type (sym) != NotSymbolic () ? variable_index (VA, sym) : sym
46- du, dprob = if i === nothing
47- getter = getobserved (VA)
48- grz = pullback (getter, sym, VA. u[j], VA. prob. p, VA. t[j])[2 ](Δ)
49- du = [k == j ? grz[2 ] : zero (VA. u[1 ]) for k in 1 : length (VA. u)]
50- dp = grz[3 ] # pullback for p
51- if dp === nothing
52- dp = parameter_values (VA)
53- end
54- dprob = remake (VA. prob, p = dp)
55- du, dprob
56- else
57- du = [m == j ? [i == k ? Δ : zero (VA. u[1 ][1 ]) for k in 1 : length (VA. u[1 ])] :
58- zero (VA. u[1 ]) for m in 1 : length (VA. u)]
59- dp = zero (VA. prob. p)
60- dprob = remake (VA. prob, p = dp)
61- du, dprob
62- end
63- T = eltype (eltype (VA. u))
64- N = ndims (VA)
65- Δ′ = ODESolution {T, N} (du, nothing , nothing ,
66- VA. t, VA. k, VA. discretes, dprob, VA. alg, VA. interp, VA. dense, 0 , VA. stats,
67- VA. alg_choice, VA. retcode)
68- (Δ′, nothing , nothing )
69- end
70- VA[sym, j], ODESolution_getindex_pullback
71- end
43+ # @adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
44+ # function ODESolution_getindex_pullback(Δ)
45+ # i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
46+ # du, dprob = if i === nothing
47+ # getter = getobserved(VA)
48+ # grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
49+ # du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
50+ # dp = grz[3] # pullback for p
51+ # if dp === nothing
52+ # dp = zeros(size(parameter_values(VA)))
53+ # dp = parameter_values(VA)
54+ # end
55+ # dprob = remake(VA.prob, p = dp)
56+ # du, dprob
57+ # else
58+ # du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
59+ # zero(VA.u[1]) for m in 1:length(VA.u)]
60+ # dp = zero(VA.prob.p)
61+ # dprob = remake(VA.prob, p = dp)
62+ # du, dprob
63+ # end
64+ # T = eltype(eltype(VA.u))
65+ # N = ndims(VA)
66+ # Δ′ = ODESolution{T, N}(du, nothing, nothing,
67+ # VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
68+ # VA.alg_choice, VA.retcode)
69+ # (Δ′, nothing, nothing)
70+ # end
71+ # VA[sym, j], ODESolution_getindex_pullback
72+ # end
7273
7374@adjoint function EnsembleSolution (sim, time, converged, stats)
7475 out = EnsembleSolution (sim, time, converged, stats)
0 commit comments