Skip to content

Commit 411ec2a

Browse files
committed
added kwarg levels for standard chunking
1 parent ab4eba2 commit 411ec2a

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

src/PreallocationTools.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,23 @@ struct DiffCache{T<:AbstractArray, S<:AbstractArray}
77
dual_du::S
88
end
99

10-
function DiffCache(u::AbstractArray{T}, siz, chunk_size::Int) where {T}
11-
x = adapt(ArrayInterface.parameterless_type(u), zeros(T, (chunk_size+1)*prod(siz)))
12-
DiffCache(u, x)
13-
end
14-
15-
function DiffCache(u::AbstractArray{T}, siz, chunk_sizes::AbstractArray{V}) where {T,V<:Int}
16-
clamp!(chunk_sizes,1,ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
10+
function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
1711
x = adapt(ArrayInterface.parameterless_type(u), zeros(T, prod(chunk_sizes .+ 1)*prod(siz)))
1812
DiffCache(u, x)
1913
end
2014

2115
"""
2216
23-
`dualcache(u::AbstractArray, N = default_cache_size(length(u)))`
17+
`dualcache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
18+
`dualcache(u::AbstractArray; N::AbstractArray{<:Int})`
2419
2520
Builds a `DualCache` object that stores both a version of the cache for `u`
2621
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
27-
forward-mode automatic differentiation.
22+
forward-mode automatic differentiation. Supports nested AD.
2823
2924
"""
30-
dualcache(u::AbstractArray, N=ForwardDiff.pickchunksize(length(u))) = DiffCache(u, size(u), N)
25+
dualcache(u::AbstractArray, N::Int=ForwardDiff.pickchunksize(length(u)); levels::Int = 1) = DiffCache(u, size(u), N^levels)
26+
dualcache(u::AbstractArray, N::AbstractArray{<:Int}) = DiffCache(u, size(u), N)
3127

3228
"""
3329

test/core_nesteddual.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ function foo(du, u, p, t)
4242
nothing
4343
end
4444

45-
ps = 2 #use to specify problem size (ps ∈ {1,2})
46-
coeffs = rand(ps^2)
47-
cache = dualcache(zeros(ps,ps), [4, 4, 4])
45+
ps = 3 #use to specify problem size; don't go crazy on this, because compilation time...
46+
coeffs = -rand(ps^2)
47+
cache = dualcache(zeros(ps,ps), [9, 9, 9])
4848
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))
4949
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.01:1.0, reltol = 1e-8)
5050

@@ -66,7 +66,7 @@ fn(x,p) = objfun(x, p[1], p[2], p[3])
6666
optfun = OptimizationFunction(fn, GalacticOptim.AutoForwardDiff())
6767
optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cache))
6868
newtonsol = solve(optprob, Newton())
69-
bfgssol = solve(optprob, BFGS()) #since only gradients are used here, we could go with a slim dualcache(zeros(ps,ps), [4,4]) as well.
69+
bfgssol = solve(optprob, BFGS()) #since only gradients are used here, we could go with a smaller dualcache(zeros(ps,ps), [9,9]) as well.
7070

7171
@test all(abs.(coeffs .- newtonsol.u) .< 1e-3)
7272
@test all(abs.(coeffs .- bfgssol.u) .< 1e-3)
@@ -81,7 +81,6 @@ function foo(du, u, p, t)
8181
nothing
8282
end
8383

84-
ps = 2 #use to specify problem size (ps ∈ {1,2})
8584
coeffs = rand(1)
8685
cache = dualcache(zeros(ps,ps), [1, 1, 4])
8786
prob = ODEProblem(foo, ones(ps, ps), (0., 1.0), (coeffs, cache))

0 commit comments

Comments
 (0)