11module PreallocationTools
22
3- using ForwardDiff, LabelledArrays
3+ using ForwardDiff, ArrayInterface, LabelledArrays
44
55struct DiffCache{T<: AbstractArray , S<: AbstractArray }
66 du:: T
77 dual_du:: S
88end
99
10- #= removed dependency on ArrayInterface, because it seemed not necessary anymore;
11- not sure whether it breaks things that are not in the testset; needs checking. =#
1210function DiffCache (u:: AbstractArray{T} , siz, :: Type{Val{chunk_size}} ) where {T, chunk_size}
13- x = zeros (T,(chunk_size+ 1 )* prod (siz))
11+ x = zeros (T,(chunk_size+ 1 )* prod (siz))
1412 DiffCache (u, x)
1513end
1614
1715"""
1816
1917`dualcache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
2018
21- Builds a `DualCache` object that stores versions of the cache for `u` and for the `Dual` version of `u` allowing use of pre-cached arrays with
19+ Builds a `DualCache` object that stores both a version of the cache for `u`
20+ and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2221forward-mode automatic differentiation.
2322
2423"""
@@ -33,17 +32,17 @@ Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`
3332"""
3433function get_tmp (dc:: DiffCache , u:: T ) where T<: ForwardDiff.Dual
3534 nelem = div (sizeof (T), sizeof (eltype (dc. dual_du)))* prod (size (dc. du))
36- reshape ( reinterpret (T, view (dc. dual_du, 1 : nelem)), size (dc . du ))
35+ ArrayInterface . restructure (dc . du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
3736end
3837
3938function get_tmp (dc:: DiffCache , u:: AbstractArray{T} ) where T<: ForwardDiff.Dual
4039 nelem = div (sizeof (T), sizeof (eltype (dc. dual_du)))* prod (size (dc. du))
41- reshape ( reinterpret (T, view (dc. dual_du, 1 : nelem)), size (dc . du ))
40+ ArrayInterface . restructure (dc . du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
4241end
4342
4443function get_tmp (dc:: DiffCache , u:: LabelledArrays.LArray{T,N,D,Syms} ) where {T,N,D,Syms}
4544 nelem = div (sizeof (T), sizeof (eltype (dc. dual_du)))* prod (size (dc. du))
46- _x = reshape ( reinterpret (T, view (dc. dual_du, 1 : nelem)), size (dc . du ))
45+ _x = ArrayInterface . restructure (dc . du, reinterpret (T, view (dc. dual_du, 1 : nelem)))
4746 LabelledArrays. LArray {T,N,D,Syms} (_x)
4847end
4948
0 commit comments