Skip to content

Commit 0395dce

Browse files
committed
Merge branch 'more-flexible-chunk-resizing' into Support-nested-duals
2 parents be939ac + 826c85c commit 0395dce

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
44
version = "0.1.1"
55

66
[deps]
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
89
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
910

1011
[compat]
12+
ArrayInterface = "2.6, 3.0"
1113
ForwardDiff = "0.10.3"
1214
LabelledArrays = "1"
1315
julia = "1.6"

src/PreallocationTools.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module PreallocationTools
22

3-
using ForwardDiff, LabelledArrays
3+
using ForwardDiff, ArrayInterface, LabelledArrays
44

55
struct DiffCache{T<:AbstractArray, S<:AbstractArray}
66
du::T
@@ -22,7 +22,8 @@ end
2222
2323
`dualcache(u::AbstractArray, N = Val{default_cache_size(length(u))})`
2424
25-
Builds a `DualCache` object that stores versions of the cache for `u` and for the `Dual` version of `u` allowing use of pre-cached arrays with
25+
Builds a `DualCache` object that stores both a version of the cache for `u`
26+
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2627
forward-mode automatic differentiation.
2728
2829
"""
@@ -37,17 +38,17 @@ Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`
3738
"""
3839
function get_tmp(dc::DiffCache, u::T) where T<:ForwardDiff.Dual
3940
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
40-
reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
41+
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4142
end
4243

4344
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
4445
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
45-
reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
46+
ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
4647
end
4748

4849
function get_tmp(dc::DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
4950
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du)))*prod(size(dc.du))
50-
_x = reshape(reinterpret(T, view(dc.dual_du, 1:nelem)), size(dc.du))
51+
_x = ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
5152
LabelledArrays.LArray{T,N,D,Syms}(_x)
5253
end
5354

0 commit comments

Comments
 (0)