Skip to content

Commit 389e261

Browse files
Merge pull request #9 from thomvet/Support-nested-duals
Support nested duals
2 parents 917c175 + a424184 commit 389e261

File tree

5 files changed

+135
-26
lines changed

5 files changed

+135
-26
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ julia = "1.6"
1919
[extras]
2020
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2121
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
22-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23-
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2422
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
23+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2524
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
25+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
26+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
27+
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
2628

2729
[targets]
28-
test = ["LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets"]
30+
test = ["LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "GalacticOptim", "Optim"]

src/PreallocationTools.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,24 @@ struct DiffCache{T<:AbstractArray, S<:AbstractArray}
77
dual_du::S
88
end
99

10-
function DiffCache(u::AbstractArray{T}, siz, chunk_size) where {T}
11-
x = adapt(ArrayInterface.parameterless_type(u), zeros(T,(chunk_size+1)*prod(siz)))
10+
function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
11+
x = adapt(ArrayInterface.parameterless_type(u), zeros(T, prod(chunk_sizes .+ 1)*prod(siz)))
1212
DiffCache(u, x)
1313
end
1414

1515
"""
1616
17-
`dualcache(u::AbstractArray, N = default_cache_size(length(u)))`
17+
`dualcache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
18+
`dualcache(u::AbstractArray, N::AbstractArray{<:Int})`
1819
1920
Builds a `DiffCache` object that stores both a version of the cache for `u`
2021
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
21-
forward-mode automatic differentiation.
22+
forward-mode automatic differentiation. Supports nested AD via keyword `levels`
23+
or specifying an array of chunk_sizes.
2224
2325
"""
24-
dualcache(u::AbstractArray, N=ForwardDiff.pickchunksize(length(u))) = DiffCache(u, size(u), N)
26+
dualcache(u::AbstractArray, N::Int=ForwardDiff.pickchunksize(length(u)); levels::Int = 1) = DiffCache(u, size(u), N*ones(Int, levels))
27+
dualcache(u::AbstractArray, N::AbstractArray{<:Int}) = DiffCache(u, size(u), N)
2528

2629
"""
2730

test/core_nesteddual.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, GalacticOptim, Optim
2+
3+
randmat = rand(5, 3)
4+
sto = similar(randmat)
5+
function claytonsample!(sto, τ, α; randmat=randmat)
6+
sto = get_tmp(sto, τ)
7+
sto .= randmat
8+
τ == 0 && return sto
9+
10+
n = size(sto, 1)
11+
for i in 1:n
12+
v = sto[i, 2]
13+
u = sto[i, 1]
14+
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
15+
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
16+
end
17+
return sto
18+
end
19+
20+
#= taking the second derivative of claytonsample! with respect to τ with manual chunk_sizes.
21+
In setting up the dualcache, we are setting chunk_size to [1, 1], because we differentiate
22+
only with respect to τ. This initializes the cache with the minimum memory needed. =#
23+
stod = dualcache(sto, [1, 1])
24+
df3 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
25+
26+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
27+
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8. Since this is greater
28+
than what's needed (1+1), the auto-allocated cache is big enough to handle the nested dual numbers, even
29+
if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
30+
especially if more levels of nesting occur (see optimization example below). =#
31+
stod = dualcache(sto)
32+
df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
33+
34+
@test df3 ≈ df4
35+
36+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
37+
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8 and with keyword arg levels = 2,
38+
the created cache size is larger than what's needed (even more so than the last example). =#
39+
stod = dualcache(sto, levels = 2)
40+
df5 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0), τ), 0.3)
41+
42+
@test df3 ≈ df5
43+
44+
#= Checking nested dual numbers using optimization problem involving Optim.jl's Newton() (involving Hessians);
45+
so, we will need one level of AD for the ODE solver (TRBDF2) and two more to calculate the Hessian =#
46+
function foo(du, u, p, t)
47+
tmp = p[2]
48+
A = reshape(p[1], size(tmp.du))
49+
tmp = get_tmp(tmp, u)
50+
mul!(tmp, A, u)
51+
@. du = u + tmp
52+
nothing
53+
end
54+
55+
ps = 3 #use to specify problem size; don't go crazy on this, because of the compilation time...
56+
coeffs = -rand(ps,ps)
57+
cache = dualcache(zeros(ps,ps), levels = 3)
58+
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
59+
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
60+
u0 = rand(length(coeffs))
61+
62+
function objfun(x, prob, realsol, cache)
63+
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
64+
sol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
65+
66+
ofv = 0.0
67+
if any((s.retcode != :Success for s in sol))
68+
ofv = 1e12
69+
else
70+
ofv = sum((sol.-realsol).^2)
71+
end
72+
return ofv
73+
end
74+
fn(x,p) = objfun(x, p[1], p[2], p[3])
75+
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
76+
optprob = OptimizationProblem(optfun, -rand(length(coeffs)), (prob, realsol, cache), chunk_size = 2)
77+
newtonsol = solve(optprob, Newton())
78+
79+
@test all(abs.(coeffs[:] .- newtonsol.u) .< 1e-2)
80+
81+
#an example where chunk_sizes are not the same on all differentiation levels:
82+
cache = dualcache(zeros(ps,ps), [9, 9, 2])
83+
realsol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
84+
85+
function objfun(x, prob, realsol, cache)
86+
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
87+
sol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
88+
89+
ofv = 0.0
90+
if any((s.retcode != :Success for s in sol))
91+
ofv = 1e12
92+
else
93+
ofv = sum((sol.-realsol).^2)
94+
end
95+
return ofv
96+
end
97+
98+
fn(x,p) = objfun(x, p[1], p[2], p[3])
99+
100+
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
101+
optprob = OptimizationProblem(optfun, -rand(length(coeffs)), (prob, realsol, cache))
102+
newtonsol2 = solve(optprob, Newton())
103+
104+
@test all(abs.(coeffs[:] .- newtonsol2.u) .< 1e-2)

test/gpu_all.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, CUDA
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff, ArrayInterface
22

33
#Dispatch tests
4+
chunk_size = 5
45
u0_CU = cu(ones(5,5))
5-
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 2, 2))
6+
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, chunk_size}, 2, 2))
7+
dual_N = ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, 5}(0)
68
cache_CU = dualcache(u0_CU, chunk_size)
79
tmp_du_CUA = get_tmp(cache_CU, u0_CU)
810
tmp_dual_du_CUA = get_tmp(cache_CU, dual_CU)
9-
tmp_du_CUN = get_tmp(cache_CU, u0_CU[1])
10-
tmp_dual_du_CUN = get_tmp(cache_CU, dual_CU[1])
11-
@test typeof(cache_CU.dual_du) == typeof(u0_CU) #check that dual cache array is a GPU array for performance reasons.
11+
tmp_du_CUN = get_tmp(cache_CU, 0.0)
12+
tmp_dual_du_CUN = get_tmp(cache_CU, dual_N)
13+
@test ArrayInterface.parameterless_type(typeof(cache_CU.dual_du)) == ArrayInterface.parameterless_type(typeof(u0_CU)) #check that dual cache array is a GPU array for performance reasons.
1214
@test size(tmp_du_CUA) == size(u0_CU)
1315
@test typeof(tmp_du_CUA) == typeof(u0_CU)
1416
@test eltype(tmp_du_CUA) == eltype(u0_CU)
@@ -33,37 +35,34 @@ end
3335
chunk_size = 10
3436
u0 = cu(rand(10,10)) #example kept small for test purposes.
3537
A = cu(-randn(10,10))
36-
cache = dualcache(A, chunk_size)
37-
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
38+
cache = dualcache(cu(zeros(10, 10)), chunk_size)
39+
prob = ODEProblem(foo, u0, (0.0f0, 1.0f0), (A, cache))
3840
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
3941
@test sol.retcode == :Success
4042

4143
#with auto-detected chunk_size
4244
u0 = cu(rand(10,10)) #example kept small for test purposes.
4345
A = cu(-randn(10,10))
44-
cache = dualcache(A)
46+
cache = dualcache(cu(zeros(10, 10)))
4547
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
4648
sol = solve(prob, TRBDF2())
4749
@test sol.retcode == :Success
4850

51+
#resizing tests
4952
randmat = cu(rand(5, 3))
5053
sto = similar(randmat)
5154
stod = dualcache(sto)
5255
function claytonsample!(sto, τ, α; randmat=randmat)
5356
sto = get_tmp(sto, τ)
54-
sto .= randmat
57+
sto .= randmat
5558
τ == 0 && return sto
56-
n = size(sto, 1)
57-
for i in 1:n
58-
v = sto[i, 2]
59-
u = sto[i, 1]
60-
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
61-
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
62-
end
59+
v = @view sto[:, 2]
60+
u = @view sto[:, 1]
61+
@. v = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
62+
@. u = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
6363
return sto
6464
end
6565

66-
#resizing tests
6766
#taking the derivative of claytonsample! with respect to τ only
6867
df1 = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
6968
@test size(randmat) == size(df1)

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ end
1313
if GROUP == "All" || GROUP == "Core"
1414
@safetestset "Dispatch" begin include("core_dispatch.jl") end
1515
@safetestset "ODE tests" begin include("core_odes.jl") end
16-
@safetestset "Base Array Resizing" begin include("core_resizing.jl") end
16+
@safetestset "Resizing" begin include("core_resizing.jl") end
17+
@safetestset "Nested Duals" begin include("core_nesteddual.jl") end
1718
end
1819

1920
if !is_APPVEYOR && GROUP == "GPU"

0 commit comments

Comments
 (0)