Skip to content

Commit 37d84a3

Browse files
committed
scalar indexing fix
1 parent 6029d0c commit 37d84a3

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

test/gpu_all.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,22 @@ prob = ODEProblem(foo, u0, (0.0f0,1.0f0), (A, cache))
4848
sol = solve(prob, TRBDF2())
4949
@test sol.retcode == :Success
5050

51+
#resizing tests
5152
randmat = cu(rand(5, 3))
5253
sto = similar(randmat)
5354
stod = dualcache(sto)
5455
function claytonsample!(sto, τ, α; randmat=randmat)
5556
sto = get_tmp(sto, τ)
56-
sto .= randmat
57+
sto .= randmat
5758
τ == 0 && return sto
5859
n = size(sto, 1)
59-
for i in 1:n
60-
v = sto[i, 2]
61-
u = sto[i, 1]
62-
sto[i, 1] = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)*α
63-
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)
64-
end
60+
v = @view sto[:, 2]
61+
u = @view sto[:, 1]
62+
@. v = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)*α
63+
@. u = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)
6564
return sto
6665
end
6766

68-
#resizing tests
6967
#taking the derivative of claytonsample! with respect to τ only
7068
df1 = ForwardDiff.derivative-> claytonsample!(stod, τ, 0.0), 0.3)
7169
@test size(randmat) == size(df1)

0 commit comments

Comments
 (0)