@@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt
22
33using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
44using ArrayInterface: ArrayInterface
5- using CommonSolve: solve
5+ using CommonSolve: CommonSolve, solve, solve!, init
6+ using ConcreteStructs: @concrete
67using DifferentiationInterface: DifferentiationInterface
78using FastClosures: @closure
8- using ForwardDiff: ForwardDiff, Dual
9+ using ForwardDiff: ForwardDiff, Dual, pickchunksize
910using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1011 NonlinearProblem, NonlinearLeastSquaresProblem, remake
1112
12- using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
13+ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
14+ NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache
1315
1416const DI = DifferentiationInterface
1517
18+ const GENERAL_SOLVER_TYPES = [
19+ Nothing, NonlinearSolvePolyAlgorithm
20+ ]
21+
22+ const DualNonlinearProblem = NonlinearProblem{
23+ <: Union{Number, <:AbstractArray} , iip,
24+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
25+ } where {iip, T, V, P}
26+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
27+ <: Union{Number, <:AbstractArray} , iip,
28+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
29+ } where {iip, T, V, P}
30+ const DualAbstractNonlinearProblem = Union{
31+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem
32+ }
33+
1634function NonlinearSolveBase. additional_incompatible_backend_check (
1735 prob:: AbstractNonlinearProblem , :: Union{AutoForwardDiff, AutoPolyesterForwardDiff} )
1836 return ! ForwardDiff. can_dual (eltype (prob. u0))
@@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
102120 return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, Utils. restructure (u, partials)))
103121end
104122
123+ for algType in GENERAL_SOLVER_TYPES
124+ @eval function SciMLBase. __solve (
125+ prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
126+ )
127+ sol, partials = NonlinearSolveBase. nonlinearsolve_forwarddiff_solve (
128+ prob, alg, args... ; kwargs...
129+ )
130+ dual_soln = NonlinearSolveBase. nonlinearsolve_dual_solution (sol. u, partials, prob. p)
131+ return SciMLBase. build_solution (
132+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
133+ )
134+ end
135+ end
136+
137+ function InternalAPI. reinit! (
138+ cache:: NonlinearSolveForwardDiffCache , args... ;
139+ p = cache. p, u0 = NonlinearSolveBase. get_u (cache. cache), kwargs...
140+ )
141+ InternalAPI. reinit! (
142+ cache. cache; p = NonlinearSolveBase. nodual_value (p),
143+ u0 = NonlinearSolveBase. nodual_value (u0), kwargs...
144+ )
145+ cache. p = p
146+ cache. values_p = NonlinearSolveBase. nodual_value (p)
147+ cache. partials_p = ForwardDiff. partials (p)
148+ return cache
149+ end
150+
151+ for algType in GENERAL_SOLVER_TYPES
152+ @eval function SciMLBase. __init (
153+ prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
154+ )
155+ p = NonlinearSolveBase. nodual_value (prob. p)
156+ newprob = SciMLBase. remake (prob; u0 = NonlinearSolveBase. nodual_value (prob. u0), p)
157+ cache = init (newprob, alg, args... ; kwargs... )
158+ return NonlinearSolveForwardDiffCache (
159+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p)
160+ )
161+ end
162+ end
163+
164+ function CommonSolve. solve! (cache:: NonlinearSolveForwardDiffCache )
165+ sol = solve! (cache. cache)
166+ prob = cache. prob
167+ uu = sol. u
168+
169+ fn = prob isa NonlinearLeastSquaresProblem ?
170+ NonlinearSolveBase. nlls_generate_vjp_function (prob, sol, uu) : prob. f
171+
172+ Jₚ = NonlinearSolveBase. nonlinearsolve_∂f_∂p (prob, fn, uu, cache. values_p)
173+ Jᵤ = NonlinearSolveBase. nonlinearsolve_∂f_∂u (prob, fn, uu, cache. values_p)
174+
175+ z_arr = - Jᵤ \ Jₚ
176+
177+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
178+ if cache. p isa Number
179+ partials = sumfun ((z_arr, cache. p))
180+ else
181+ partials = sum (sumfun, zip (eachcol (z_arr), cache. p))
182+ end
183+
184+ dual_soln = NonlinearSolveBase. nonlinearsolve_dual_solution (sol. u, partials, cache. p)
185+ return SciMLBase. build_solution (
186+ prob, cache. alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
187+ )
188+ end
189+
190+ NonlinearSolveBase. nodual_value (x) = x
191+ NonlinearSolveBase. nodual_value (x:: Dual ) = ForwardDiff. value (x)
192+ NonlinearSolveBase. nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
193+
194+ @inline NonlinearSolveBase. pickchunksize (x) = pickchunksize (length (x))
195+ @inline NonlinearSolveBase. pickchunksize (x:: Int ) = ForwardDiff. pickchunksize (x)
196+
105197end
0 commit comments