Skip to content

Commit 1358905

Browse files
Merge pull request #102 from LilithHafner/lh/lazy-get-tmp
Make Lazy caches support get_tmp
2 parents de6e393 + aae9a51 commit 1358905

File tree

8 files changed

+50
-30
lines changed

8 files changed

+50
-30
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,11 @@ lbc = GeneralLazyBufferCache(function (p)
268268
end)
269269
```
270270

271-
then `lbc[p]` will be smart and reuse the caches. A full example looks like the following:
271+
then `lbc[p]` (or, equivalently, `get_tmp(lbc, p)`) will be smart and reuse the caches. A full example looks like the following:
272272

273273
```julia
274274
using Random, DifferentialEquations, LinearAlgebra, Optimization, OptimizationNLopt,
275-
OptimizationOptimJL, PreallocationTools
275+
OptimizationOptimJL, PreallocationTools
276276

277277
lbc = GeneralLazyBufferCache(function (p)
278278
DifferentialEquations.init(ODEProblem(ode_fnc, y₀, (0.0, T), p), Tsit5(); saveat = t)

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ then `lbc[p]` will be smart and reuse the caches. A full example looks like the
267267

268268
```julia
269269
using Random, DifferentialEquations, LinearAlgebra, Optimization, OptimizationNLopt,
270-
OptimizationOptimJL, PreallocationTools
270+
OptimizationOptimJL, PreallocationTools
271271

272272
lbc = GeneralLazyBufferCache(function (p)
273273
DifferentialEquations.init(ODEProblem(ode_fnc, y₀, (0.0, T), p), Tsit5(); saveat = t)

src/PreallocationTools.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,16 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N}
216216
typeof(similar(x, ntuple(Returns(1), N)))
217217
end
218218

219-
# override the [] method
220-
function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
219+
function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray}
221220
s = b.sizemap(size(u)) # required buffer size
222221
get!(b.bufs, (T, s)) do
223222
similar(u, s) # buffer to allocate if it was not found in b.bufs
224223
end::similar_type(u, s) # declare type since b.bufs dictionary is untyped
225224
end
226225

226+
# override the [] method
227+
Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u)
228+
227229
# GeneralLazyBufferCache
228230

229231
"""
@@ -246,11 +248,12 @@ struct GeneralLazyBufferCache{F <: Function}
246248
GeneralLazyBufferCache(f::F = identity) where {F <: Function} = new{F}(Dict(), f) # start with empty dict
247249
end
248250

249-
function Base.getindex(b::GeneralLazyBufferCache, u::T) where {T}
251+
function get_tmp(b::GeneralLazyBufferCache, u::T) where {T}
250252
get!(b.bufs, T) do
251253
b.f(u)
252254
end
253255
end
256+
Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
254257

255258
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
256259
export get_tmp

test/core_dispatch.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra,
2-
Test, PreallocationTools, ForwardDiff, LabelledArrays,
3-
RecursiveArrayTools
2+
Test, PreallocationTools, ForwardDiff, LabelledArrays,
3+
RecursiveArrayTools
44

55
function test(u0, dual, chunk_size)
66
cache = PreallocationTools.DiffCache(u0, chunk_size)
@@ -53,8 +53,10 @@ results = test(u0, dual, chunk_size)
5353

5454
chunk_size = 5
5555
u0_B = ones(5, 5)
56-
dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
57-
chunk_size}, 2, 2)
56+
dual_B = zeros(
57+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
58+
chunk_size},
59+
2, 2)
5860
cache_B = FixedSizeDiffCache(u0_B, chunk_size)
5961
tmp_du_BA = get_tmp(cache_B, u0_B)
6062
tmp_dual_du_BA = get_tmp(cache_B, dual_B)
@@ -102,9 +104,11 @@ results = test(u0, dual, chunk_size)
102104
#ArrayPartition tests
103105
chunk_size = 2
104106
u0 = ArrayPartition(ones(2, 2), ones(3, 3))
105-
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
107+
dual_a = zeros(
108+
ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
106109
chunk_size}, 2, 2)
107-
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
110+
dual_b = zeros(
111+
ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
108112
chunk_size}, 3, 3)
109113
dual = ArrayPartition(dual_a, dual_b)
110114
results = test(u0, dual, chunk_size)
@@ -128,10 +132,14 @@ results = test(u0, dual, chunk_size)
128132
@test eltype(results[7]) == eltype(dual)
129133

130134
u0_AP = ArrayPartition(ones(2, 2), ones(3, 3))
131-
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
132-
chunk_size}, 2, 2)
133-
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
134-
chunk_size}, 3, 3)
135+
dual_a = zeros(
136+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
137+
chunk_size},
138+
2, 2)
139+
dual_b = zeros(
140+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
141+
chunk_size},
142+
3, 3)
135143
dual_AP = ArrayPartition(dual_a, dual_b)
136144
cache_AP = FixedSizeDiffCache(u0_AP, chunk_size)
137145
tmp_du_APA = get_tmp(cache_AP, u0_AP)

test/core_nesteddual.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra,
2-
OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, Optimization,
3-
OptimizationOptimJL
2+
OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, Optimization,
3+
OptimizationOptimJL
44

55
randmat = rand(5, 3)
66
sto = similar(randmat)
@@ -23,7 +23,8 @@ end
2323
In setting up the DiffCache, we are setting chunk_size to [1, 1], because we differentiate
2424
only with respect to τ. This initializes the cache with the minimum memory needed. =#
2525
stod = DiffCache(sto, [1, 1])
26-
df3 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
26+
df3 = ForwardDiff.derivative(
27+
τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
2728
τ), 0.3)
2829

2930
#= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
@@ -32,7 +33,8 @@ than what's needed (1+1), the auto-allocated cache is big enough to handle the n
3233
if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
3334
especially if more levels of nesting occur (see optimization example below). =#
3435
stod = DiffCache(sto)
35-
df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
36+
df4 = ForwardDiff.derivative(
37+
τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
3638
τ), 0.3)
3739

3840
@test df3 ≈ df4
@@ -41,7 +43,8 @@ df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(s
4143
For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8 and with keyword arg levels = 2,
4244
the created cache size is larger than what's needed (even more so than the last example). =#
4345
stod = DiffCache(sto, levels = 2)
44-
df5 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
46+
df5 = ForwardDiff.derivative(
47+
τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
4548
τ), 0.3)
4649

4750
@test df3 ≈ df5

test/core_odes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra,
2-
OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
3-
RecursiveArrayTools
2+
OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
3+
RecursiveArrayTools
44

55
#Base array
66
function foo(du, u, (A, tmp), t)

test/general_lbc.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Random,
2-
OrdinaryDiffEq, LinearAlgebra, Optimization, OptimizationOptimJL,
3-
PreallocationTools
2+
OrdinaryDiffEq, LinearAlgebra, Optimization, OptimizationOptimJL,
3+
PreallocationTools
44

55
lbc = GeneralLazyBufferCache(function (p)
66
init(ODEProblem(ode_fnc, y₀,
@@ -40,6 +40,7 @@ x = rand(1000)
4040
y = view(x, 1:900)
4141
@inferred cache[y]
4242
@test 0 == @allocated cache[y]
43+
@test cache[y] === get_tmp(cache, y)
4344

4445
cache_17 = LazyBufferCache(Returns(17))
4546
x = 1:10
@@ -52,3 +53,4 @@ cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000))
5253
# @inferred cache[Float64]
5354
cache[Float64] # generate the buffer
5455
@test 0 == @allocated cache[Float64]
56+
@test get_tmp(cache, Float64) === cache[Float64]

test/gpu_all.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
using LinearAlgebra,
2-
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
2+
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
33

44
# upstream
55
OrdinaryDiffEq.DiffEqBase.anyeltypedual(x::FixedSizeDiffCache, counter = 0) = Any
66

77
#Dispatch tests
88
chunk_size = 5
99
u0_CU = cu(ones(5, 5))
10-
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
11-
chunk_size}, 2, 2))
10+
dual_CU = cu(zeros(
11+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
12+
chunk_size},
13+
2, 2))
1214
dual_N = ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, 5}(0)
1315
cache_CU = DiffCache(u0_CU, chunk_size)
1416
tmp_du_CUA = get_tmp(cache_CU, u0_CU)
@@ -32,8 +34,10 @@ tmp_dual_du_CUN = get_tmp(cache_CU, dual_N)
3234

3335
chunk_size = 5
3436
u0_B = cu(ones(5, 5))
35-
dual_B = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
36-
chunk_size}, 2, 2))
37+
dual_B = cu(zeros(
38+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
39+
chunk_size},
40+
2, 2))
3741
cache_B = FixedSizeDiffCache(u0_B, chunk_size)
3842
tmp_du_BA = get_tmp(cache_B, u0_B)
3943
tmp_dual_du_BA = get_tmp(cache_B, dual_B)

0 commit comments

Comments
 (0)