Skip to content

Commit 8b289b6

Browse files
committed
Merge branch 'more-flexible-chunk-resizing' into Support-nested-duals
2 parents 715e2d5 + 8073d82 commit 8b289b6

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

test/gpu_all.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, CUDA
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff, ArrayInterface
22

33
#Dispatch tests
4+
chunk_size = 5
45
u0_CU = cu(ones(5,5))
5-
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, chunk_size}, 2, 2))
6+
dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, chunk_size}, 2, 2))
7+
dual_N = ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, 5}(0)
68
cache_CU = dualcache(u0_CU, chunk_size)
79
tmp_du_CUA = get_tmp(cache_CU, u0_CU)
810
tmp_dual_du_CUA = get_tmp(cache_CU, dual_CU)
9-
tmp_du_CUN = get_tmp(cache_CU, u0_CU[1])
10-
tmp_dual_du_CUN = get_tmp(cache_CU, dual_CU[1])
11-
@test typeof(cache_CU.dual_du) == typeof(u0_CU) #check that dual cache array is a GPU array for performance reasons.
11+
tmp_du_CUN = get_tmp(cache_CU, 0.0)
12+
tmp_dual_du_CUN = get_tmp(cache_CU, dual_N)
13+
@test ArrayInterface.parameterless_type(typeof(cache_CU.dual_du)) == ArrayInterface.parameterless_type(typeof(u0_CU)) #check that dual cache array is a GPU array for performance reasons.
1214
@test size(tmp_du_CUA) == size(u0_CU)
1315
@test typeof(tmp_du_CUA) == typeof(u0_CU)
1416
@test eltype(tmp_du_CUA) == eltype(u0_CU)
@@ -33,37 +35,34 @@ end
3335
chunk_size = 10
3436
u0 = cu(rand(10,10)) #example kept small for test purposes.
3537
A = cu(-randn(10,10))
36-
cache = dualcache(A, chunk_size)
37-
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
38+
cache = dualcache(cu(zeros(10, 10)), chunk_size)
39+
prob = ODEProblem(foo, u0, (0.0f0, 1.0f0), (A, cache))
3840
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
3941
@test sol.retcode == :Success
4042

4143
#with auto-detected chunk_size
4244
u0 = cu(rand(10,10)) #example kept small for test purposes.
4345
A = cu(-randn(10,10))
44-
cache = dualcache(A)
46+
cache = dualcache(cu(zeros(10, 10)))
4547
prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
4648
sol = solve(prob, TRBDF2())
4749
@test sol.retcode == :Success
4850

51+
#resizing tests
4952
randmat = cu(rand(5, 3))
5053
sto = similar(randmat)
5154
stod = dualcache(sto)
5255
function claytonsample!(sto, τ, α; randmat=randmat)
5356
sto = get_tmp(sto, τ)
54-
sto .= randmat
57+
sto .= randmat
5558
τ == 0 && return sto
56-
n = size(sto, 1)
57-
for i in 1:n
58-
v = sto[i, 2]
59-
u = sto[i, 1]
60-
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)*α
61-
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)
62-
end
59+
v = @view sto[:, 2]
60+
u = @view sto[:, 1]
61+
@. v = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)*α
62+
@. u = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)
6363
return sto
6464
end
6565

66-
#resizing tests
6766
#taking the derivative of claytonsample! with respect to τ only
6867
df1 = ForwardDiff.derivative-> claytonsample!(stod, τ, 0.0), 0.3)
6968
@test size(randmat) == size(df1)

0 commit comments

Comments
 (0)