Skip to content

Commit 12bc059

Browse files
committed
Merge branch 'more-flexible-chunk-resizing' into Support-nested-duals
2 parents 0395dce + 422ce22 commit 12bc059

File tree

3 files changed

+248
-63
lines changed

3 files changed

+248
-63
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
44
version = "0.1.1"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -15,9 +16,11 @@ LabelledArrays = "1"
1516
julia = "1.6"
1617

1718
[extras]
19+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1820
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1921
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2022
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2124

2225
[targets]
23-
test = ["LinearAlgebra", "OrdinaryDiffEq", "Test"]
26+
test = ["LinearAlgebra", "OrdinaryDiffEq", "Test", "CUDA", "RecursiveArrayTools"]

src/PreallocationTools.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
module PreallocationTools
22

3-
using ForwardDiff, ArrayInterface, LabelledArrays
3+
using ForwardDiff, ArrayInterface, LabelledArrays, Adapt
44

55
struct DiffCache{T<:AbstractArray, S<:AbstractArray}
66
du::T
77
dual_du::S
88
end
99

10-
function DiffCache(u::AbstractArray{T}, siz, chunk_size::V) where {T,V<:Int}
11-
x = zeros(T,(chunk_size+1)*prod(siz))
10+
function DiffCache(u::AbstractArray{T}, siz, chunk_size::Int) where {T}
11+
x = adapt(ArrayInterface.parameterless_type(u), zeros(T, (chunk_size+1)*prod(siz)))
1212
DiffCache(u, x)
1313
end
1414

1515
function DiffCache(u::AbstractArray{T}, siz, chunk_sizes::AbstractArray{V}) where {T,V<:Int}
1616
clamp!(chunk_sizes,1,ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
17-
x = zeros(T,prod(chunk_sizes.+1)*prod(siz))
17+
x = zeros(T,prod(chunk_sizes.+1)*prod(siz)) adapt(ArrayInterface.parameterless_type(u), zeros(T, (chunk_sizes .+ 1)*prod(siz)))
1818
DiffCache(u, x)
1919
end
2020

2121
"""
2222
23-
`dualcache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
23+
`dualcache(u::AbstractArray, N = default_cache_size(length(u)))`
2424
2525
Builds a `DualCache` object that stores both a version of the cache for `u`
2626
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
@@ -37,17 +37,17 @@ Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`
3737
3838
"""
3939
function get_tmp(dc::DiffCache, u::T) where T<:ForwardDiff.Dual
40-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
41-
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
40+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*length(dc.du)
41+
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4242
end
4343

4444
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
45-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
45+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*length(dc.du)
4646
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4747
end
4848

4949
function get_tmp(dc::DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
50-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
50+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*length(dc.du)
5151
_x = ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
5252
LabelledArrays.LArray{T,N,D,Syms}(_x)
5353
end

test/runtests.jl

Lines changed: 235 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,65 +10,238 @@ end
1010
chunk_size = 5
1111
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), dualcache(zeros(5,5), chunk_size)))
1212
solve(prob, TRBDF2(chunk_size=chunk_size))
13+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, LabelledArrays, CUDA, RecursiveArrayTools
1314

14-
## Check ODE problem with auto-detected chunk_size
15-
function foo(du, u, (A, tmp), t)
16-
tmp = get_tmp(tmp, u)
17-
mul!(tmp, A, u)
18-
@. du = u + tmp
19-
nothing
20-
end
21-
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), dualcache(zeros(5,5))))
22-
solve(prob, TRBDF2())
15+
@testset verbose = true "PreallocationTools tests" begin
16+
@testset "Dispatch" verbose = true begin #tests dispatching without changing chunk_size
17+
chunk_size = 5
18+
#base array tests
19+
@testset "Base Arrays" begin
20+
u0_B = ones(5, 5)
21+
dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 2, 2)
22+
cache_B = dualcache(u0_B, Val{chunk_size})
23+
tmp_du_BA = get_tmp(cache_B, u0_B)
24+
tmp_dual_du_BA = get_tmp(cache_B, dual_B)
25+
tmp_du_BN = get_tmp(cache_B, u0_B[1])
26+
tmp_dual_du_BN = get_tmp(cache_B, dual_B[1])
27+
@test size(tmp_du_BA) == size(u0_B)
28+
@test typeof(tmp_du_BA) == typeof(u0_B)
29+
@test eltype(tmp_du_BA) == eltype(u0_B)
30+
@test size(tmp_dual_du_BA) == size(u0_B)
31+
@test typeof(tmp_dual_du_BA) == typeof(dual_B)
32+
@test eltype(tmp_dual_du_BA) == eltype(dual_B)
33+
@test size(tmp_du_BN) == size(u0_B)
34+
@test typeof(tmp_du_BN) == typeof(u0_B)
35+
@test eltype(tmp_du_BN) == eltype(u0_B)
36+
@test size(tmp_dual_du_BN) == size(u0_B)
37+
@test typeof(tmp_dual_du_BN) == typeof(dual_B)
38+
@test eltype(tmp_dual_du_BN) == eltype(dual_B)
39+
end
40+
@testset "Labelled Arrays" begin
41+
chunk_size = 4
42+
u0_L = LArray((2,2); a=1.0, b=1.0, c=1.0, d=1.0)
43+
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size})
44+
dual_L = LArray((2,2); a=zerodual, b=zerodual, c=zerodual, d=zerodual)
45+
cache_L = dualcache(u0_L, Val{chunk_size})
46+
tmp_du_LA = get_tmp(cache_L, u0_L)
47+
tmp_dual_du_LA = get_tmp(cache_L, dual_L)
48+
tmp_du_LN = get_tmp(cache_L, u0_L[1])
49+
tmp_dual_du_LN = get_tmp(cache_L, dual_L[1])
50+
@test size(tmp_du_LA) == size(u0_L)
51+
@test typeof(tmp_du_LA) == typeof(u0_L)
52+
@test eltype(tmp_du_LA) == eltype(u0_L)
53+
@test size(tmp_dual_du_LA) == size(u0_L)
54+
@test typeof(tmp_dual_du_LA) == typeof(dual_L)
55+
@test eltype(tmp_dual_du_LA) == eltype(dual_L)
56+
@test size(tmp_du_LN) == size(u0_L)
57+
@test typeof(tmp_du_LN) == typeof(u0_L)
58+
@test eltype(tmp_du_LN) == eltype(u0_L)
59+
@test size(tmp_dual_du_LN) == size(u0_L)
60+
@test typeof(tmp_dual_du_LN) == typeof(dual_L)
61+
@test eltype(tmp_dual_du_LN) == eltype(dual_L)
62+
end
63+
@testset "Array Partitions" begin
64+
u0_AP = ArrayPartition(ones(2,2), ones(3,3))
65+
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 2, 2)
66+
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 3, 3)
67+
dual_AP = ArrayPartition(dual_a, dual_b)
68+
cache_AP = dualcache(u0_AP, Val{chunk_size})
69+
tmp_du_APA = get_tmp(cache_AP, u0_AP)
70+
tmp_dual_du_APA = get_tmp(cache_AP, dual_AP)
71+
tmp_du_APN = get_tmp(cache_AP, u0_AP[1])
72+
tmp_dual_du_APN = get_tmp(cache_AP, dual_AP[1])
73+
@test size(tmp_du_APA) == size(u0_AP)
74+
@test typeof(tmp_du_APA) == typeof(u0_AP)
75+
@test eltype(tmp_du_APA) == eltype(u0_AP)
76+
@test size(tmp_dual_du_APA) == size(u0_AP)
77+
@test typeof(tmp_dual_du_APA) == typeof(dual_AP)
78+
@test eltype(tmp_dual_du_APA) == eltype(dual_AP)
79+
@test size(tmp_du_APN) == size(u0_AP)
80+
@test typeof(tmp_du_APN) == typeof(u0_AP)
81+
@test eltype(tmp_du_APN) == eltype(u0_AP)
82+
@test size(tmp_dual_du_APN) == size(u0_AP)
83+
@test typeof(tmp_dual_du_APN) == typeof(dual_AP)
84+
@test eltype(tmp_dual_du_APN) == eltype(dual_AP)
85+
end
86+
@testset "Cu Arrays" begin
87+
u0_CU = cu(ones(5,5))
88+
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 2, 2))
89+
cache_CU = dualcache(u0_CU, Val{chunk_size})
90+
tmp_du_CUA = get_tmp(cache_CU, u0_CU)
91+
tmp_dual_du_CUA = get_tmp(cache_CU, dual_CU)
92+
tmp_du_CUN = get_tmp(cache_CU, u0_CU[1])
93+
tmp_dual_du_CUN = get_tmp(cache_CU, dual_CU[1])
94+
@test typeof(cache_CU.dual_du) == typeof(u0_CU) #check that dual cache array is a GPU array for performance reasons.
95+
@test size(tmp_du_CUA) == size(u0_CU)
96+
@test typeof(tmp_du_CUA) == typeof(u0_CU)
97+
@test eltype(tmp_du_CUA) == eltype(u0_CU)
98+
@test size(tmp_dual_du_CUA) == size(u0_CU)
99+
@test typeof(tmp_dual_du_CUA) == typeof(dual_CU)
100+
@test eltype(tmp_dual_du_CUA) == eltype(dual_CU)
101+
@test size(tmp_du_CUN) == size(u0_CU)
102+
@test typeof(tmp_du_CUN) == typeof(u0_CU)
103+
@test eltype(tmp_du_CUN) == eltype(u0_CU)
104+
@test size(tmp_dual_du_CUN) == size(u0_CU)
105+
@test typeof(tmp_dual_du_CUN) == typeof(dual_CU)
106+
@test eltype(tmp_dual_du_CUN) == eltype(dual_CU)
107+
end
108+
end
109+
@testset "ODE tests" verbose = true begin
110+
@testset "Base Array" begin
111+
function foo(du, u, (A, tmp), t)
112+
tmp = get_tmp(tmp, u)
113+
mul!(tmp, A, u)
114+
@. du = u + tmp
115+
nothing
116+
end
117+
#with defined chunk_size
118+
chunk_size = 5
119+
u0 = ones(5, 5)
120+
A = ones(5,5)
121+
cache = dualcache(zeros(5,5), Val{chunk_size})
122+
prob = ODEProblem(foo, u0, (0., 1.0), (A, cache))
123+
sol = solve(prob, TRBDF2(chunk_size=chunk_size))
124+
@test sol.retcode == :Success
125+
126+
#with auto-detected chunk_size
127+
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), dualcache(zeros(5,5))))
128+
sol = solve(prob, TRBDF2())
129+
@test sol.retcode == :Success
130+
end
23131

24-
## Check ODE problem with a lazy buffer cache
25-
function foo(du, u, (A, lbc), t)
26-
tmp = lbc[u]
27-
mul!(tmp, A, u)
28-
@. du = u + tmp
29-
nothing
30-
end
31-
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), LazyBufferCache()))
32-
solve(prob, TRBDF2())
132+
@testset "Base Array and LBC" begin
133+
function foo(du, u, (A, lbc), t)
134+
tmp = lbc[u]
135+
mul!(tmp, A, u)
136+
@. du = u + tmp
137+
nothing
138+
end
139+
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), LazyBufferCache()))
140+
sol = solve(prob, TRBDF2())
141+
@test sol.retcode == :Success
142+
end
33143

34-
## Check ODE problem with auto-detected chunk_size and LArray
35-
A = LArray((2,2); a=1.0, b=1.0, c=1.0, d=1.0)
36-
u0 = LArray((2,2); a=1.0, b=1.0, c=1.0, d=1.0)
37-
function foo(du, u, (A, tmp), t)
38-
tmp = get_tmp(tmp, u)
39-
mul!(tmp, A, u)
40-
@. du = u + tmp
41-
nothing
42-
end
43-
prob = ODEProblem(foo, u0, (0., 1.0), (A, dualcache(A)))
44-
solve(prob, TRBDF2())
45-
46-
## Check resizing
47-
randmat = rand(5, 3)
48-
sto = similar(randmat)
49-
stod = dualcache(sto)
50-
51-
function claytonsample!(sto, τ, α; randmat=randmat)
52-
sto = get_tmp(sto, τ)
53-
sto .= randmat
54-
τ == 0 && return sto
55-
56-
n = size(sto, 1)
57-
for i in 1:n
58-
v = sto[i, 2]
59-
u = sto[i, 1]
60-
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
61-
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
144+
@testset "LArray" begin
145+
A = LArray((2,2); a=1.0, b=1.0, c=1.0, d=1.0)
146+
c = LArray((2,2); a=0.0, b=0.0, c=0.0, d=0.0)
147+
u0 = LArray((2,2); a=1.0, b=1.0, c=1.0, d=1.0)
148+
function foo(du, u, (A, tmp), t)
149+
tmp = get_tmp(tmp, u)
150+
mul!(tmp, A, u)
151+
@. du = u + tmp
152+
nothing
153+
end
154+
#with specified chunk_size
155+
chunk_size = 4
156+
prob = ODEProblem(foo, u0, (0., 1.0), (A, dualcache(c, Val{chunk_size})))
157+
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
158+
@test sol.retcode == :Success
159+
#with auto-detected chunk_size
160+
prob = ODEProblem(foo, u0, (0., 1.0), (A, dualcache(c)))
161+
sol = solve(prob, TRBDF2())
162+
@test sol.retcode == :Success
163+
end
164+
165+
@testset "cuarray" begin
166+
function foo(du, u, (A, tmp), t)
167+
tmp = get_tmp(tmp, u)
168+
mul!(tmp, A, u)
169+
@. du = u + tmp
170+
nothing
171+
end
172+
#with specified chunk_size
173+
chunk_size = 10
174+
u0 = cu(rand(10,10)) #example kept small for test purposes.
175+
A = cu(-randn(10,10))
176+
cache = dualcache(A, Val{chunk_size})
177+
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
178+
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
179+
@test sol.retcode == :Success
180+
181+
#with auto-detected chunk_size
182+
u0 = cu(rand(10,10)) #example kept small for test purposes.
183+
A = cu(-randn(10,10))
184+
cache = dualcache(A)
185+
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
186+
sol = solve(prob, TRBDF2())
187+
@test sol.retcode == :Success
188+
end
62189
end
63-
return sto
64-
end
65190

66-
#taking the derivative of claytonsample! with respect to τ only
67-
df1 = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
191+
@testset "Change of chunk_size" verbose = true begin
192+
@testset "Base array" begin
193+
randmat = rand(5, 3)
194+
sto = similar(randmat)
195+
stod = dualcache(sto)
196+
197+
function claytonsample!(sto, τ, α; randmat=randmat)
198+
sto = get_tmp(sto, τ)
199+
sto .= randmat
200+
τ == 0 && return sto
201+
202+
n = size(sto, 1)
203+
for i in 1:n
204+
v = sto[i, 2]
205+
u = sto[i, 1]
206+
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
207+
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
208+
end
209+
return sto
210+
end
68211

69-
#calculating the jacobian of claytonsample! with respect to τ and α
70-
df2 = ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0]) #should give a 15x2 array,
71-
#because ForwardDiff flattens the output of jacobian, see: https://juliadiff.org/ForwardDiff.jl/stable/user/api/#ForwardDiff.jacobian
212+
#taking the derivative of claytonsample! with respect to τ only
213+
df1 = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
214+
@test size(randmat) == size(df1)
215+
216+
#calculating the jacobian of claytonsample! with respect to τ and α
217+
df2 = ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0]) #should give a 15x2 array,
218+
#because ForwardDiff flattens the output of jacobian, see: https://juliadiff.org/ForwardDiff.jl/stable/user/api/#ForwardDiff.jacobian
219+
220+
@test (length(randmat), 2) == size(df2)
221+
@test df1[1:5,2] ≈ df2[6:10,1]
222+
end
223+
224+
@testset "cuarray" begin
225+
randmat = cu(rand(5, 3))
226+
sto = similar(randmat)
227+
stod = dualcache(sto)
228+
function claytonsample!(sto, τ, α; randmat=randmat)
229+
sto = get_tmp(sto, τ)
230+
sto .= randmat
231+
τ == 0 && return sto
232+
n = size(sto, 1)
233+
for i in 1:n
234+
v = sto[i, 2]
235+
u = sto[i, 1]
236+
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)*α
237+
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-(τ/(1 + τ))))^(-1/τ)
238+
end
239+
return sto
240+
end
241+
242+
#taking the derivative of claytonsample! with respect to τ only
243+
df1 = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
244+
@test size(randmat) == size(df1)
72245

73246
@test df1[1:5,2] ≈ df2[6:10,1]
74247

@@ -165,3 +338,12 @@ optprob = OptimizationProblem(optfun, rand(size(coeffs)...), (prob, realsol, cac
165338
newtonsol2 = solve(optprob, Newton())
166339

167340
@test all(abs.(coeffs .- newtonsol2.u) .< 1e-3)
341+
#calculating the jacobian of claytonsample! with respect to τ and α
342+
df2 = ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0]) #should give a 15x2 array,
343+
#because ForwardDiff flattens the output of jacobian, see: https://juliadiff.org/ForwardDiff.jl/stable/user/api/#ForwardDiff.jacobian
344+
345+
@test (length(randmat), 2) == size(df2)
346+
@test df1[1:5,2] ≈ df2[6:10,1]
347+
end
348+
end
349+
end

0 commit comments

Comments
 (0)