Skip to content

Commit 826c85c

Browse files
committed
re-introduced ArrayInterface.restructure
1 parent ef77d8d commit 826c85c

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
44
version = "0.1.1"
55

66
[deps]
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
89
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
910

1011
[compat]
12+
ArrayInterface = "2.6, 3.0"
1113
ForwardDiff = "0.10.3"
1214
LabelledArrays = "1"
1315
julia = "1.6"

src/PreallocationTools.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
module PreallocationTools
22

3-
using ForwardDiff, LabelledArrays
3+
using ForwardDiff, ArrayInterface, LabelledArrays
44

55
struct DiffCache{T<:AbstractArray, S<:AbstractArray}
66
du::T
77
dual_du::S
88
end
99

10-
#= removed dependency on ArrayInterface, because it seemed not necessary anymore;
11-
not sure whether it breaks things that are not in the testset; needs checking. =#
1210
function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size}
13-
x = zeros(T,(chunk_size+1)*prod(siz))
11+
x = zeros(T,(chunk_size+1)*prod(siz))
1412
DiffCache(u, x)
1513
end
1614

1715
"""
1816
1917
`dualcache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
2018
21-
Builds a `DualCache` object that stores versions of the cache for `u` and for the `Dual` version of `u` allowing use of pre-cached arrays with
19+
Builds a `DualCache` object that stores both a version of the cache for `u`
20+
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2221
forward-mode automatic differentiation.
2322
2423
"""
@@ -33,17 +32,17 @@ Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`
3332
"""
3433
function get_tmp(dc::DiffCache, u::T) where T<:ForwardDiff.Dual
3534
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
36-
reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
35+
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
3736
end
3837

3938
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
4039
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
41-
reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
40+
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4241
end
4342

4443
function get_tmp(dc::DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
4544
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
46-
_x = reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
45+
_x = ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4746
LabelledArrays.LArray{T,N,D,Syms}(_x)
4847
end
4948

0 commit comments

Comments
 (0)