Skip to content

Commit a424184

Browse files
committed
support for nested duals; tests included
1 parent 8b289b6 commit a424184

File tree

2 files changed

+35
-38
lines changed

2 files changed

+35
-38
lines changed

src/PreallocationTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ end
1919
2020
Builds a `DualCache` object that stores both a version of the cache for `u`
2121
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
22-
forward-mode automatic differentiation. Supports nested AD.
22+
forward-mode automatic differentiation. Supports nested AD via keyword `levels`
23+
or specifying an array of chunk_sizes.
2324
2425
"""
2526
dualcache(u::AbstractArray, N::Int=ForwardDiff.pickchunksize(length(u)); levels::Int = 1) = DiffCache(u, size(u), N*ones(Int, levels))

test/core_nesteddual.jl

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,32 @@ function claytonsample!(sto, τ, α; randmat=randmat)
1717
return sto
1818
end
1919

20-
#= taking the second derivative of claytonsample! with respect to τ with manual chunk_sizes. In setting up the dualcache,
21-
we are setting chunk_size to [1, 1], because we differentiate only with respect to τ.
22-
This initializes the cache with the minimum memory needed. =#
20+
#= taking the second derivative of claytonsample! with respect to τ with manual chunk_sizes.
21+
In setting up the dualcache, we are setting chunk_size to [1, 1], because we differentiate
22+
only with respect to τ. This initializes the cache with the minimum memory needed. =#
2323
stod = dualcache(sto, [1, 1])
2424
df3 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
2525

26-
#= taking the second derivative of claytonsample! with respect to τ, auto-detect. For the given size of sto, ForwardDiff's heuristic
27-
chooses chunk_size = 8. Since this is greater than what's needed (1+1), the auto-allocated cache is big enough to handle the nested
28-
dual numbers. This should in general not be relied on to work, especially if more levels of nesting occurs (as below). =#
26+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
27+
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8. Since this is greater
28+
than what's needed (1+1), the auto-allocated cache is big enough to handle the nested dual numbers, even
29+
if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
30+
especially if more levels of nesting occur (see optimization example below). =#
2931
stod = dualcache(sto)
3032
df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
3133

3234
@test df3 ≈ df4
3335

34-
## Checking nested dual numbers: Checking an optimization problem inspired by the above tests
35-
## (using Optim.jl's Newton() (involving Hessians) and BFGS() (involving gradients))
36+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
37+
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8 and with keyword arg levels = 2,
38+
the created cache size is larger than what's needed (even more so than the last example). =#
39+
stod = dualcache(sto, levels = 2)
40+
df5 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
41+
42+
@test df3 ≈ df5
43+
44+
#= Checking nested dual numbers using optimization problem involving Optim.jl's Newton() (involving Hessians);
45+
so, we will need one level of AD for the ODE solver (TRBDF2) and two more to calculate the Hessian =#
3646
function foo(du, u, p, t)
3747
tmp = p[2]
3848
A = reshape(p[1], size(tmp.du))
@@ -42,15 +52,16 @@ function foo(du, u, p, t)
4252
nothing
4353
end
4454

45-
ps = 3 #use to specify problem size; don't go crazy on this, because compilation time...
46-
coeffs = -rand(ps^2)
47-
cache = dualcache(zeros(ps,ps), [9, 9, 9])
55+
ps = 3 #use to specify problem size; don't go crazy on this, because of the compilation time...
56+
coeffs = -rand(ps,ps)
57+
cache = dualcache(zeros(ps,ps), levels = 3)
4858
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
49-
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
59+
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
60+
u0 = rand(length(coeffs))
5061

5162
function objfun(x, prob, realsol, cache)
52-
prob = remake(prob, u0 = eltype(x).(ones(ps, ps)), p = (x, cache))
53-
sol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
63+
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
64+
sol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
5465

5566
ofv = 0.0
5667
if any((s.retcode != :Success for s in sol))
@@ -60,35 +71,20 @@ function objfun(x, prob, realsol, cache)
6071
end
6172
return ofv
6273
end
63-
6474
fn(x,p) = objfun(x, p[1], p[2], p[3])
65-
6675
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
67-
optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cache))
76+
optprob = OptimizationProblem(optfun, -rand(length(coeffs)), (prob, realsol, cache), chunk_size = 2)
6877
newtonsol = solve(optprob, Newton())
69-
bfgssol = solve(optprob, BFGS()) #since only gradients are used here, we could go with a smaller dualcache(zeros(ps,ps), [9,9]) as well.
7078

71-
@test all(abs.(coeffs .- newtonsol.u) .< 1e-3)
72-
@test all(abs.(coeffs .- bfgssol.u) .< 1e-3)
79+
@test all(abs.(coeffs[:] .- newtonsol.u) .< 1e-2)
7380

7481
#an example where chunk_sizes are not the same on all differentiation levels:
75-
function foo(du, u, p, t)
76-
tmp = p[2]
77-
A = ones(size(tmp.du)).*p[1]
78-
tmp = get_tmp(tmp, u)
79-
mul!(tmp, A, u)
80-
@. du = u + tmp
81-
nothing
82-
end
83-
84-
coeffs = rand(1)
85-
cache = dualcache(zeros(ps,ps), [1, 1, 4])
86-
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
87-
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
82+
cache = dualcache(zeros(ps,ps), [9, 9, 2])
83+
realsol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
8884

8985
function objfun(x, prob, realsol, cache)
90-
prob = remake(prob, u0 = eltype(x).(ones(ps, ps)), p = (x, cache))
91-
sol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
86+
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
87+
sol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
9288

9389
ofv = 0.0
9490
if any((s.retcode != :Success for s in sol))
@@ -102,7 +98,7 @@ end
10298
fn(x,p) = objfun(x, p[1], p[2], p[3])
10399

104100
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
105-
optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cache))
101+
optprob = OptimizationProblem(optfun, -rand(length(coeffs)), (prob, realsol, cache))
106102
newtonsol2 = solve(optprob, Newton())
107103

108-
@test all(abs.(coeffs .- newtonsol2.u) .< 1e-3)
104+
@test all(abs.(coeffs[:] .- newtonsol2.u) .< 1e-2)

0 commit comments

Comments
 (0)