Skip to content

Commit 55d1268

Browse files
Merge pull request #53 from SciML/resizing
Add FixedSizeDiffCache
2 parents 5c8e225 + a9282ea commit 55d1268

File tree

8 files changed

+206
-46
lines changed

8 files changed

+206
-46
lines changed

src/PreallocationTools.jl

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,73 @@ module PreallocationTools
22

33
using ForwardDiff, ArrayInterfaceCore, Adapt
44

5+
struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
6+
du::T
7+
dual_du::S
8+
any_du::Vector{Any}
9+
end
10+
11+
function FixedSizeDiffCache(u::AbstractArray{T}, siz,
12+
::Type{Val{chunk_size}}) where {T, chunk_size}
13+
x = ArrayInterfaceCore.restructure(u,
14+
zeros(ForwardDiff.Dual{nothing, T, chunk_size},
15+
siz...))
16+
xany = Any[]
17+
FixedSizeDiffCache(deepcopy(u), x, xany)
18+
end
19+
20+
"""
21+
22+
`FixedSizeDiffCache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
23+
24+
Builds a `DualCache` object that stores both a version of the cache for `u`
25+
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
26+
forward-mode automatic differentiation.
27+
"""
28+
function FixedSizeDiffCache(u::AbstractArray,
29+
::Type{Val{N}} = Val{ForwardDiff.pickchunksize(length(u))}) where {
30+
N
31+
}
32+
FixedSizeDiffCache(u, size(u), Val{N})
33+
end
34+
35+
function FixedSizeDiffCache(u::AbstractArray, N::Integer)
36+
FixedSizeDiffCache(u, size(u), Val{N})
37+
end
38+
39+
chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
40+
41+
function get_tmp(dc::FixedSizeDiffCache, u::T) where {T <: ForwardDiff.Dual}
42+
x = reinterpret(T, dc.dual_du)
43+
if chunksize(T) === chunksize(eltype(dc.dual_du))
44+
x
45+
else
46+
@view x[axes(dc.du)...]
47+
end
48+
end
49+
50+
function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
51+
x = reinterpret(T, dc.dual_du)
52+
if chunksize(T) === chunksize(eltype(dc.dual_du))
53+
x
54+
else
55+
@view x[axes(dc.du)...]
56+
end
57+
end
58+
59+
function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray})
60+
if promote_type(eltype(dc.du), eltype(u)) <: eltype(dc.du)
61+
dc.du
62+
else
63+
if length(dc.du) > length(dc.any_du)
64+
resize!(dc.any_du, length(dc.du))
65+
end
66+
_restructure(dc.du, dc.any_du)
67+
end
68+
end
69+
70+
# DiffCache
71+
572
struct DiffCache{T <: AbstractArray, S <: AbstractArray}
673
du::T
774
dual_du::S
@@ -16,25 +83,27 @@ function DiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
1683
end
1784

1885
"""
86+
`DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
87+
`DiffCache(u::AbstractArray; N::AbstractArray{<:Int})`
1988
20-
`dualcache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1)`
21-
`dualcache(u::AbstractArray; N::AbstractArray{<:Int})`
22-
23-
Builds a `DualCache` object that stores both a version of the cache for `u`
89+
Builds a `DiffCache` object that stores both a version of the cache for `u`
2490
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2591
forward-mode automatic differentiation. Supports nested AD via keyword `levels`
2692
or specifying an array of chunk_sizes.
2793
2894
"""
29-
function dualcache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
95+
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
3096
levels::Int = 1)
3197
DiffCache(u, size(u), N * ones(Int, levels))
3298
end
33-
dualcache(u::AbstractArray, N::AbstractArray{<:Int}) = DiffCache(u, size(u), N)
34-
function dualcache(u::AbstractArray, ::Type{Val{N}}; levels::Int = 1) where {N}
35-
dualcache(u, N; levels)
99+
DiffCache(u::AbstractArray, N::AbstractArray{<:Int}) = DiffCache(u, size(u), N)
100+
function DiffCache(u::AbstractArray, ::Type{Val{N}}; levels::Int = 1) where {N}
101+
DiffCache(u, N; levels)
36102
end
37-
dualcache(u::AbstractArray, ::Val{N}; levels::Int = 1) where {N} = dualcache(u, N; levels)
103+
DiffCache(u::AbstractArray, ::Val{N}; levels::Int = 1) where {N} = dualcache(u, N; levels)
104+
105+
# Legacy deprecate later
106+
const dualcache = DiffCache
38107

39108
"""
40109
@@ -89,6 +158,8 @@ function enlargedualcache!(dc, nelem) #warning comes only once per dualcache.
89158
resize!(dc.dual_du, nelem)
90159
end
91160

161+
# LazyBufferCache
162+
92163
"""
93164
b = LazyBufferCache(f=identity)
94165
@@ -112,6 +183,7 @@ function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
112183
return buf
113184
end
114185

115-
export dualcache, get_tmp, LazyBufferCache
186+
export FixedSizeDiffCache, DiffCache, LazyBufferCache
187+
export get_tmp
116188

117189
end

test/core_dispatch.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using LinearAlgebra, Test, PreallocationTools, ForwardDiff, LabelledArrays,
22
RecursiveArrayTools
33

44
function test(u0, dual, chunk_size)
5-
cache = PreallocationTools.dualcache(u0, chunk_size)
5+
cache = PreallocationTools.DiffCache(u0, chunk_size)
66
allocs_normal1 = @allocated get_tmp(cache, u0)
77
allocs_normal2 = @allocated get_tmp(cache, first(u0))
88
allocs_dual1 = @allocated get_tmp(cache, dual)
@@ -41,6 +41,28 @@ results = test(u0, dual, chunk_size)
4141
@test eltype(results[5]) == eltype(u0)
4242
@test eltype(results[7]) == eltype(dual)
4343

44+
chunk_size = 5
45+
u0_B = ones(5, 5)
46+
dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
47+
chunk_size}, 2, 2)
48+
cache_B = FixedSizeDiffCache(u0_B, chunk_size)
49+
tmp_du_BA = get_tmp(cache_B, u0_B)
50+
tmp_dual_du_BA = get_tmp(cache_B, dual_B)
51+
tmp_du_BN = get_tmp(cache_B, u0_B[1])
52+
tmp_dual_du_BN = get_tmp(cache_B, dual_B[1])
53+
@test size(tmp_du_BA) == size(u0_B)
54+
@test typeof(tmp_du_BA) == typeof(u0_B)
55+
@test eltype(tmp_du_BA) == eltype(u0_B)
56+
@test size(tmp_dual_du_BA) == size(u0_B)
57+
@test_broken typeof(tmp_dual_du_BA) == typeof(dual_B)
58+
@test eltype(tmp_dual_du_BA) == eltype(dual_B)
59+
@test size(tmp_du_BN) == size(u0_B)
60+
@test typeof(tmp_du_BN) == typeof(u0_B)
61+
@test eltype(tmp_du_BN) == eltype(u0_B)
62+
@test size(tmp_dual_du_BN) == size(u0_B)
63+
@test_broken typeof(tmp_dual_du_BN) == typeof(dual_B)
64+
@test eltype(tmp_dual_du_BN) == eltype(dual_B)
65+
4466
#LArray tests
4567
chunk_size = 4
4668
u0 = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
@@ -94,3 +116,27 @@ results = test(u0, dual, chunk_size)
94116
#eltype tests
95117
@test eltype(results[5]) == eltype(u0)
96118
@test eltype(results[7]) == eltype(dual)
119+
120+
u0_AP = ArrayPartition(ones(2, 2), ones(3, 3))
121+
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
122+
chunk_size}, 2, 2)
123+
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
124+
chunk_size}, 3, 3)
125+
dual_AP = ArrayPartition(dual_a, dual_b)
126+
cache_AP = FixedSizeDiffCache(u0_AP, chunk_size)
127+
tmp_du_APA = get_tmp(cache_AP, u0_AP)
128+
tmp_dual_du_APA = get_tmp(cache_AP, dual_AP)
129+
tmp_du_APN = get_tmp(cache_AP, u0_AP[1])
130+
tmp_dual_du_APN = get_tmp(cache_AP, dual_AP[1])
131+
@test size(tmp_du_APA) == size(u0_AP)
132+
@test typeof(tmp_du_APA) == typeof(u0_AP)
133+
@test eltype(tmp_du_APA) == eltype(u0_AP)
134+
@test size(tmp_dual_du_APA) == size(u0_AP)
135+
@test_broken typeof(tmp_dual_du_APA) == typeof(dual_AP)
136+
@test eltype(tmp_dual_du_APA) == eltype(dual_AP)
137+
@test size(tmp_du_APN) == size(u0_AP)
138+
@test typeof(tmp_du_APN) == typeof(u0_AP)
139+
@test eltype(tmp_du_APN) == eltype(u0_AP)
140+
@test size(tmp_dual_du_APN) == size(u0_AP)
141+
@test_broken typeof(tmp_dual_du_APN) == typeof(dual_AP)
142+
@test eltype(tmp_dual_du_APN) == eltype(dual_AP)

test/core_nesteddual.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@ function claytonsample!(sto, τ, α; randmat = randmat)
1919
end
2020

2121
#= taking the second derivative of claytonsample! with respect to τ with manual chunk_sizes.
22-
In setting up the dualcache, we are setting chunk_size to [1, 1], because we differentiate
22+
In setting up the DiffCache, we are setting chunk_size to [1, 1], because we differentiate
2323
only with respect to τ. This initializes the cache with the minimum memory needed. =#
24-
stod = dualcache(sto, [1, 1])
24+
stod = DiffCache(sto, [1, 1])
2525
df3 = ForwardDiff.derivative-> ForwardDiff.derivative-> claytonsample!(stod, ξ, 0.0),
2626
τ), 0.3)
2727

28-
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
29-
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8. Since this is greater
28+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
29+
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8. Since this is greater
3030
than what's needed (1+1), the auto-allocated cache is big enough to handle the nested dual numbers, even
31-
if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
31+
if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
3232
especially if more levels of nesting occur (see optimization example below). =#
33-
stod = dualcache(sto)
33+
stod = DiffCache(sto)
3434
df4 = ForwardDiff.derivative-> ForwardDiff.derivative-> claytonsample!(stod, ξ, 0.0),
3535
τ), 0.3)
3636

3737
@test df3 df4
3838

39-
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
39+
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
4040
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8 and with keyword arg levels = 2,
4141
the created cache size is larger than what's needed (even more so than the last example). =#
42-
stod = dualcache(sto, levels = 2)
42+
stod = DiffCache(sto, levels = 2)
4343
df5 = ForwardDiff.derivative-> ForwardDiff.derivative-> claytonsample!(stod, ξ, 0.0),
4444
τ), 0.3)
4545

@@ -58,7 +58,7 @@ end
5858

5959
ps = 2 #use to specify problem size; don't go crazy on this, because of the compilation time...
6060
coeffs = -collect(0.1:0.1:(ps^2 / 10))
61-
cache = dualcache(zeros(ps, ps), levels = 3)
61+
cache = DiffCache(zeros(ps, ps), levels = 3)
6262
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(ps, ps), (0.0, 1.0),
6363
(coeffs, cache))
6464
realsol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
@@ -68,7 +68,7 @@ function objfun(x, prob, realsol, cache)
6868
sol = solve(prob, TRBDF2(), saveat = 0.0:0.1:10.0, reltol = 1e-8)
6969

7070
ofv = 0.0
71-
if any((s.retcode != :Success for s in sol))
71+
if any((s.retcode != ReturnCode.Success for s in sol))
7272
ofv = 1e12
7373
else
7474
ofv = sum((sol .- realsol) .^ 2)
@@ -83,7 +83,7 @@ newtonsol = solve(optprob, Newton())
8383
@test all(abs.(coeffs .- newtonsol.u) .< 1e-3)
8484

8585
#an example where chunk_sizes are not the same on all differentiation levels:
86-
cache = dualcache(zeros(ps, ps), [4, 4, 2])
86+
cache = DiffCache(zeros(ps, ps), [4, 4, 2])
8787
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(ps, ps), (0.0, 1.0),
8888
(coeffs, cache))
8989
realsol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
@@ -93,7 +93,7 @@ function objfun(x, prob, realsol, cache)
9393
sol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
9494

9595
ofv = 0.0
96-
if any((s.retcode != :Success for s in sol))
96+
if any((s.retcode != ReturnCode.Success for s in sol))
9797
ofv = 1e12
9898
else
9999
ofv = sum((sol .- realsol) .^ 2)

test/core_odes.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
22
RecursiveArrayTools
33

4+
# upstream
5+
OrdinaryDiffEq.DiffEqBase.anyeltypedual(x::FixedSizeDiffCache, counter = 0) = Any
6+
47
#Base array
58
function foo(du, u, (A, tmp), t)
69
tmp = get_tmp(tmp, u)
@@ -12,16 +15,26 @@ end
1215
chunk_size = 5
1316
u0 = ones(5, 5)
1417
A = ones(5, 5)
15-
cache = dualcache(zeros(5, 5), chunk_size)
18+
cache = DiffCache(zeros(5, 5), chunk_size)
19+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
20+
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
21+
@test sol.retcode == ReturnCode.Success
22+
23+
cache = FixedSizeDiffCache(zeros(5, 5), chunk_size)
1624
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
1725
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
18-
@test sol.retcode == :Success
26+
@test sol.retcode == ReturnCode.Success
1927

2028
#with auto-detected chunk_size
21-
cache = dualcache(zeros(5, 5))
29+
cache = DiffCache(zeros(5, 5))
2230
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0), (A, cache))
2331
sol = solve(prob, TRBDF2())
24-
@test sol.retcode == :Success
32+
@test sol.retcode == ReturnCode.Success
33+
34+
prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0),
35+
(ones(5, 5), FixedSizeDiffCache(zeros(5, 5))))
36+
sol = solve(prob, TRBDF2())
37+
@test sol.retcode == ReturnCode.Success
2538

2639
#Base array with LBC
2740
function foo(du, u, (A, lbc), t)
@@ -33,7 +46,7 @@ end
3346
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0),
3447
(ones(5, 5), LazyBufferCache()))
3548
sol = solve(prob, TRBDF2())
36-
@test sol.retcode == :Success
49+
@test sol.retcode == ReturnCode.Success
3750

3851
#LArray
3952
A = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
@@ -48,10 +61,10 @@ end
4861
#with specified chunk_size
4962
chunk_size = 4
5063
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0),
51-
(A, dualcache(c, chunk_size)))
64+
(A, DiffCache(c, chunk_size)))
5265
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
53-
@test sol.retcode == :Success
66+
@test sol.retcode == ReturnCode.Success
5467
#with auto-detected chunk_size
55-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, dualcache(c)))
68+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c)))
5669
sol = solve(prob, TRBDF2())
57-
@test sol.retcode == :Success
70+
@test sol.retcode == ReturnCode.Success

test/core_resizing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test, PreallocationTools, ForwardDiff, LinearAlgebra
33
#test for downsizing cache
44
randmat = rand(5, 3)
55
sto = similar(randmat)
6-
stod = dualcache(sto)
6+
stod = DiffCache(sto)
77

88
function claytonsample!(sto, τ, α; randmat = randmat)
99
sto = get_tmp(sto, τ)
@@ -48,7 +48,7 @@ u = [3.0, 0.0]
4848
A = ones(2, 2)
4949

5050
du = similar(u)
51-
_du = dualcache(du)
51+
_du = DiffCache(du)
5252
f = A -> loss(_du, u, A, 0.0)
5353
analyticalsolution = [3.0 0; 0 0]
5454
@test ForwardDiff.gradient(f, A) analyticalsolution

0 commit comments

Comments
 (0)