Skip to content

Commit 943a778

Browse files
Merge pull request #100 from LilithHafner/lh/lbc-fix
Fix and test LazyBufferCache on types that are not fixed points of similar [bugfix]
2 parents 5501e45 + ab41468 commit 943a778

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ tries to do this with a bump allocator.
316316

317317
- See the [SciML Style Guide](https://github.com/SciML/SciMLStyle) for common coding practices and other style decisions.
318318
- There are a few community forums:
319-
319+
320320
+ The #diffeq-bridged and #sciml-bridged channels in the
321321
[Julia Slack](https://julialang.org/slack/)
322322
+ The #diffeq-bridged and #sciml-bridged channels in the

src/PreallocationTools.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
99
end
1010

1111
function FixedSizeDiffCache(u::AbstractArray{T}, siz,
12-
::Type{Val{chunk_size}}) where {T, chunk_size}
12+
::Type{Val{chunk_size}}) where {T, chunk_size}
1313
x = ArrayInterface.restructure(u,
1414
zeros(ForwardDiff.Dual{nothing, T, chunk_size},
1515
siz...))
@@ -25,8 +25,8 @@ and for the `Dual` version of `u`, allowing use of pre-cached vectors with
2525
forward-mode automatic differentiation.
2626
"""
2727
function FixedSizeDiffCache(u::AbstractArray,
28-
::Type{Val{N}} = Val{ForwardDiff.pickchunksize(length(u))}) where {
29-
N,
28+
::Type{Val{N}} = Val{ForwardDiff.pickchunksize(length(u))}) where {
29+
N,
3030
}
3131
FixedSizeDiffCache(u, size(u), Val{N})
3232
end
@@ -75,7 +75,7 @@ function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray})
7575
end
7676
end
7777

78-
function get_tmp(dc::FixedSizeDiffCache, ::Type{T}) where T <: Number
78+
function get_tmp(dc::FixedSizeDiffCache, ::Type{T}) where {T <: Number}
7979
if promote_type(eltype(dc.du), T) <: eltype(dc.du)
8080
dc.du
8181
else
@@ -111,7 +111,7 @@ forward-mode automatic differentiation. Supports nested AD via keyword `levels`
111111
or specifying an array of chunk_sizes.
112112
"""
113113
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
114-
levels::Int = 1)
114+
levels::Int = 1)
115115
DiffCache(u, size(u), N * ones(Int, levels))
116116
end
117117
DiffCache(u::AbstractArray, N::AbstractArray{<:Int}) = DiffCache(u, size(u), N)
@@ -164,7 +164,7 @@ function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})
164164
end
165165
end
166166

167-
function get_tmp(dc::DiffCache, ::Type{T}) where T <: Number
167+
function get_tmp(dc::DiffCache, ::Type{T}) where {T <: Number}
168168
if promote_type(eltype(dc.du), T) <: eltype(dc.du)
169169
dc.du
170170
else
@@ -209,12 +209,18 @@ struct LazyBufferCache{F <: Function}
209209
LazyBufferCache(f::F = identity) where {F <: Function} = new{F}(Dict(), f) # start with empty dict
210210
end
211211

212+
function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N}
213+
# The compiler is smart enough to not allocate
214+
# here for simple types like Array and SubArray
215+
typeof(similar(x, ntuple(Returns(1), N)))
216+
end
217+
212218
# override the [] method
213219
function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
214220
s = b.sizemap(size(u)) # required buffer size
215221
get!(b.bufs, (T, s)) do
216222
similar(u, s) # buffer to allocate if it was not found in b.bufs
217-
end::T # declare type since b.bufs dictionary is untyped
223+
end::similar_type(u, s) # declare type since b.bufs dictionary is untyped
218224
end
219225

220226
# GeneralLazyBufferCache

test/general_lbc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ cache = LazyBufferCache()
3737
x = rand(1000)
3838
@inferred cache[x]
3939
@test 0 == @allocated cache[x]
40+
y = view(x, 1:900)
41+
@inferred cache[y]
42+
@test 0 == @allocated cache[y]
4043

4144
cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000))
4245
# GeneralLazyBufferCache is documented not to infer.

0 commit comments

Comments
 (0)