Skip to content

Commit 090adf9

Browse files
reshape to automatically handle smaller sizes
1 parent f1978a9 commit 090adf9

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/PreallocationTools.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,34 @@ forward-mode automatic differentiation.
2323
"""
2424
dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N)
2525

26+
chunksize(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N
27+
2628
function get_tmp(dc::DiffCache, u::T) where T<:ForwardDiff.Dual
2729
x = reinterpret(T, dc.dual_du)
30+
if chunksize(T) === chunksize(eltype(dc.dual_du))
31+
x
32+
else
33+
@view x[axes(dc.du)...]
34+
end
2835
end
2936

3037
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
3138
x = reinterpret(T, dc.dual_du)
39+
if chunksize(T) === chunksize(eltype(dc.dual_du))
40+
x
41+
else
42+
@view x[axes(dc.du)...]
43+
end
3244
end
3345

3446
function get_tmp(dc::DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
3547
x = reinterpret(T, dc.dual_du.__x)
36-
LabelledArrays.LArray{T,N,D,Syms}(x)
48+
_x = if chunksize(T) === chunksize(eltype(dc.dual_du))
49+
x
50+
else
51+
@view x[axes(dc.du)...]
52+
end
53+
LabelledArrays.LArray{T,N,D,Syms}(_x)
3754
end
3855

3956
get_tmp(dc::DiffCache, u::Number) = dc.du

test/runtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,25 @@ function foo(du, u, (A, lbc), t)
2828
end
2929
prob = ODEProblem(foo, ones(5, 5), (0., 1.0), (ones(5,5), LazyBufferCache()))
3030
solve(prob, TRBDF2())
31+
32+
## Check resizing
33+
34+
randmat = rand(10, 2)
35+
sto = similar(randmat)
36+
stod = dualcache(sto)
37+
38+
function claytonsample!(sto, τ; randmat=randmat)
39+
sto = get_tmp(sto, τ)
40+
sto .= randmat
41+
τ == 0 && return sto
42+
43+
n = size(sto, 1)
44+
for i in 1:n
45+
v = sto[i, 2]
46+
u = sto[i, 1]
47+
sto[i, 2] = (1 - u^(-τ) + u^(-τ)*v^(-/(1 + τ))))^(-1/τ)
48+
end
49+
return sto
50+
end
51+
52+
ForwardDiff.derivative-> claytonsample!(stod, τ), 0.3)

0 commit comments

Comments
 (0)