Skip to content

Commit be939ac

Browse files
committed
Introduced support for nested dual numbers
1 parent ef77d8d commit be939ac

File tree

2 files changed

+105
-7
lines changed

2 files changed

+105
-7
lines changed

src/PreallocationTools.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ struct DiffCache{T<:AbstractArray, S<:AbstractArray}
77
dual_du::S
88
end
99

10-
#= removed dependency on ArrayInterface, because it seemed not necessary anymore;
11-
not sure whether it breaks things that are not in the testset; needs checking. =#
12-
function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size}
10+
function DiffCache(u::AbstractArray{T}, siz, chunk_size::V) where {T,V<:Int}
1311
x = zeros(T,(chunk_size+1)*prod(siz))
1412
DiffCache(u, x)
1513
end
1614

15+
function DiffCache(u::AbstractArray{T}, siz, chunk_sizes::AbstractArray{V}) where {T,V<:Int}
16+
clamp!(chunk_sizes,1,ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
17+
x = zeros(T,prod(chunk_sizes.+1)*prod(siz))
18+
DiffCache(u, x)
19+
end
20+
1721
"""
1822
1923
`dualcache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
@@ -22,7 +26,7 @@ Builds a `DualCache` object that stores versions of the cache for `u` and for th
2226
forward-mode automatic differentiation.
2327
2428
"""
25-
dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N)
29+
dualcache(u::AbstractArray, N=ForwardDiff.pickchunksize(length(u))) = DiffCache(u, size(u), N)
2630

2731
"""
2832

test/runtests.jl

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, LabelledArrays
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, LabelledArrays, GalacticOptim, Optim
22

33
## Check ODE problem with specified chunk_size
44
function foo(du, u, (A, tmp), t)
@@ -8,7 +8,7 @@ function foo(du, u, (A, tmp), t)
88
nothing
99
end
1010
chunk_size = 5
11-
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), dualcache(zeros(5,5), Val{chunk_size})))
11+
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), dualcache(zeros(5,5), chunk_size)))
1212
solve(prob, TRBDF2(chunk_size=chunk_size))
1313

1414
## Check ODE problem with auto-detected chunk_size
@@ -70,4 +70,98 @@ df1 = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
7070
df2 = ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0]) #should give a 15x2 array,
7171
#because ForwardDiff flattens the output of jacobian, see: https://juliadiff.org/ForwardDiff.jl/stable/user/api/#ForwardDiff.jacobian
7272

73-
@test all(df1[1:5,2] df2[6:10,1])
73+
@test df1[1:5,2] df2[6:10,1]
74+
75+
76+
## Checking nested dual numbers: second derivatives
77+
78+
#= taking the second derivative of claytonsample! with respect to τ with manual chunk_sizes. In setting up the dualcache,
79+
we are setting chunk_size to [1, 1], because we differentiate only twice with respect to τ.
80+
This initializes the cache with the minimum memory needed. =#
81+
stod = dualcache(sto, [1, 1])
82+
df3 = ForwardDiff.derivative-> ForwardDiff.derivative-> claytonsample!(stod, ξ, 0.0), τ), 0.3)
83+
84+
#= taking the second derivative of claytonsample! with respect to τ, auto-detect. For the given size of sto, ForwardDiff's heuristic
85+
chooses chunk_size = 8. Since this is greater than (1+1)^2 = 4, the auto-allocated cache is big enough to handle the nested
86+
dual numbers. This should in general not be relied on to work, especially if more levels of nesting occurs (as below). =#
87+
stod = dualcache(sto)
88+
df4 = ForwardDiff.derivative-> ForwardDiff.derivative-> claytonsample!(stod, ξ, 0.0), τ), 0.3)
89+
90+
@test df3 df4
91+
92+
## Checking nested dual numbers: Checking an optimization problem inspired by the above tests
93+
## (using Optim.jl's Newton() (involving Hessians) and BFGS() (involving gradients))
94+
function foo(du, u, p, t)
95+
tmp = p[2]
96+
A = reshape(p[1], size(tmp.du))
97+
tmp = get_tmp(tmp, u)
98+
mul!(tmp, A, u)
99+
@. du = u + tmp
100+
nothing
101+
end
102+
103+
ps = 2 #use to specify problem size (ps ∈ {1,2})
104+
coeffs = rand(ps^2)
105+
cache = dualcache(zeros(ps,ps), [4, 4, 4])
106+
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
107+
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
108+
109+
function objfun(x, prob, realsol, cache)
110+
prob = remake(prob, u0 = eltype(x).(ones(ps, ps)), p = (x, cache))
111+
sol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
112+
113+
ofv = 0.0
114+
if any((s.retcode != :Success for s in sol))
115+
ofv = 1e12
116+
else
117+
ofv = sum((sol.-realsol).^2)
118+
end
119+
return ofv
120+
end
121+
122+
fn(x,p) = objfun(x, p[1], p[2], p[3])
123+
124+
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
125+
optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cache))
126+
newtonsol = solve(optprob, Newton())
127+
bfgssol = solve(optprob, BFGS()) #since only gradients are used here, we could go with a slim dualcache(zeros(ps,ps), [4,4]) as well.
128+
129+
@test all(abs.(coeffs .- newtonsol.u) .< 1e-3)
130+
@test all(abs.(coeffs .- bfgssol.u) .< 1e-3)
131+
132+
#an example where chunk_sizes are not the same on all differentiation levels:
133+
function foo(du, u, p, t)
134+
tmp = p[2]
135+
A = ones(size(tmp.du)).*p[1]
136+
tmp = get_tmp(tmp, u)
137+
mul!(tmp, A, u)
138+
@. du = u + tmp
139+
nothing
140+
end
141+
142+
ps = 2 #use to specify problem size (ps ∈ {1,2})
143+
coeffs = rand(1)
144+
cache = dualcache(zeros(ps,ps), [1, 1, 4])
145+
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
146+
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
147+
148+
function objfun(x, prob, realsol, cache)
149+
prob = remake(prob, u0 = eltype(x).(ones(ps, ps)), p = (x, cache))
150+
sol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
151+
152+
ofv = 0.0
153+
if any((s.retcode != :Success for s in sol))
154+
ofv = 1e12
155+
else
156+
ofv = sum((sol.-realsol).^2)
157+
end
158+
return ofv
159+
end
160+
161+
fn(x,p) = objfun(x, p[1], p[2], p[3])
162+
163+
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
164+
optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cache))
165+
newtonsol2 = solve(optprob, Newton())
166+
167+
@test all(abs.(coeffs .- newtonsol2.u) .< 1e-3)

0 commit comments

Comments
 (0)