11module SerializedArrays
22
3+ export SerializedArray, disk, memory
4+
35using Base. PermutedDimsArrays: genperm
46using ConstructionBase: constructorof
57using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
68using Serialization: deserialize, serialize
79
8- memory (a) = a
10+ adapt_serialized (to, x) = adapt_structure_serialized (to, x)
11+ adapt_serialized (to) = Base. Fix1 (adapt_structure_serialized, to)
12+ adapt_structure_serialized (to, x) = adapt_storage_serialized (to, x)
13+ adapt_storage_serialized (to, x) = x
14+
15+ struct DeepMemoryAdaptor end
16+ deepmemory (x) = adapt_serialized (DeepMemoryAdaptor (), x)
17+
18+ struct MemoryAdaptor end
19+ memory (x) = adapt_serialized (MemoryAdaptor (), x)
920
1021#
1122# AbstractSerializedArray
@@ -15,9 +26,12 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
1526const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2 }
1627const AbstractSerializedVector{T} = AbstractSerializedArray{T,1 }
1728
18- memory (a:: AbstractSerializedArray ) = copy (a)
1929disk (a:: AbstractSerializedArray ) = a
2030
31+ function Base. copy (a:: AbstractSerializedArray )
32+ return copy (memory (a))
33+ end
34+
2135function _copyto_write! (dst, src)
2236 writeblock! (dst, src, axes (src)... )
2337 return dst
@@ -62,18 +76,6 @@ function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
6276 return equals_serialized (a1, a2)
6377end
6478
65- # # These cause too many ambiguity errors, try bringing them back.
66- # function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
67- # return arrayt(a)
68- # end
69- # function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
70- # return convert(arrayt, memory(a))
71- # end
72- # # Fixes ambiguity error.
73- # function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
74- # return convert(arrayt, memory(a))
75- # end
76-
7779#
7880# SerializedArray
7981#
@@ -105,11 +107,19 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
105107 return constructorof (arraytype (a)){elt}(undef, dims... )
106108end
107109
108- function materialize (a:: SerializedArray )
110+ function _memory (a:: SerializedArray )
109111 return deserialize (file (a)):: arraytype (a)
110112end
113+
114+ function adapt_storage_serialized (:: DeepMemoryAdaptor , a:: SerializedArray )
115+ return _memory (a)
116+ end
117+ function adapt_storage_serialized (:: MemoryAdaptor , a:: SerializedArray )
118+ return _memory (a)
119+ end
120+
111121function Base. copy (a:: SerializedArray )
112- return materialize (a)
122+ return memory (a)
113123end
114124
115125Base. size (a:: SerializedArray ) = length .(axes (a))
@@ -123,7 +133,7 @@ function DiskArrays.readblock!(
123133 a:: SerializedArray{<:Any,N} , aout, i:: Vararg{AbstractUnitRange,N}
124134) where {N}
125135 if i == axes (a)
126- aout .= memory (a)
136+ aout .= deepmemory (a)
127137 return a
128138 end
129139 aout .= @view memory (a)[i... ]
@@ -179,11 +189,13 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{
179189 return similar (parent (a), elt, dims)
180190end
181191
182- function materialize ( a:: PermutedSerializedArray )
183- return PermutedDimsArray (memory ( parent (a)), perm (a))
192+ function adapt_structure_serialized (to, a:: PermutedSerializedArray )
193+ return PermutedDimsArray (adapt_serialized (to, parent (a)), perm (a))
184194end
185- function Base. copy (a:: PermutedSerializedArray )
186- return copy (materialize (a))
195+
196+ # Special case to eagerly instantiate permutations.
197+ function adapt_structure_serialized (to:: MemoryAdaptor , a:: PermutedSerializedArray )
198+ return copy (deepmemory (a))
187199end
188200
189201haschunks (a:: PermutedSerializedArray ) = Unchunked ()
@@ -238,19 +250,14 @@ function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{
238250 return similar (parent (a), elt, dims)
239251end
240252
241- function materialize ( a:: ReshapedSerializedArray )
242- return reshape (materialize ( parent (a)), axes (a))
253+ function adapt_structure_serialized (to, a:: ReshapedSerializedArray )
254+ return reshape (adapt_serialized (to, parent (a)), axes (a))
243255end
244256function Base. copy (a:: ReshapedSerializedArray )
245- a′ = materialize (a)
246- return a′ isa Base. ReshapedArray ? copy (a′) : a′
247- end
248-
249- # Special case for handling nested wrappers that aren't
250- # friendly on GPU. Consider special cases of strded arrays
251- # and handle with stride manipulations.
252- function Base. copy (a:: ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray} )
253- a′ = reshape (memory (parent (a)), axes (a))
257+ # `memory` instantiates `PermutedSerializedArray`, which is
258+ # friendlier for GPU. Consider special cases of strded arrays
259+ # and handle with stride manipulations.
260+ a′ = memory (a)
254261 return a′ isa Base. ReshapedArray ? copy (a′) : a′
255262end
256263
@@ -306,17 +313,14 @@ Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
306313Base. parent (a:: SubSerializedArray ) = parent (a. sub_parent)
307314Base. parentindices (a:: SubSerializedArray ) = parentindices (a. sub_parent)
308315
309- function materialize (a:: SubSerializedArray )
310- return view (copy (parent (a)), parentindices (a)... )
311- end
312- function Base. copy (a:: SubSerializedArray )
313- return copy (materialize (a))
316+ function adapt_structure_serialized (to, a:: SubSerializedArray )
317+ return view (adapt_serialized (to, parent (a)), parentindices (a)... )
314318end
315319
316320DiskArrays. haschunks (a:: SubSerializedArray ) = Unchunked ()
317321function DiskArrays. readblock! (a:: SubSerializedArray , aout, i:: OrdinalRange... )
318322 if i == axes (a)
319- aout .= memory (a)
323+ aout .= deepmemory (a)
320324 end
321325 aout[i... ] = memory (view (a, i... ))
322326 return nothing
@@ -326,7 +330,7 @@ function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
326330 serialize (file (a), ain)
327331 return a
328332 end
329- a_parent = memory (parent (a))
333+ a_parent = deepmemory (parent (a))
330334 pinds = parentindices (view (a. sub_parent, i... ))
331335 a_parent[pinds... ] = ain
332336 serialize (file (a), a_parent)
@@ -357,11 +361,8 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg
357361 return similar (parent (a), elt, dims)
358362end
359363
360- function materialize (a:: TransposeSerializedArray )
361- return transpose (memory (parent (a)))
362- end
363- function Base. copy (a:: TransposeSerializedArray )
364- return copy (materialize (a))
364+ function adapt_structure_serialized (to, a:: TransposeSerializedArray )
365+ return transpose (adapt_serialized (to, parent (a)))
365366end
366367
367368haschunks (a:: TransposeSerializedArray ) = Unchunked ()
@@ -400,11 +401,8 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I
400401 return similar (parent (a), elt, dims)
401402end
402403
403- function materialize (a:: AdjointSerializedArray )
404- return adjoint (memory (parent (a)))
405- end
406- function Base. copy (a:: AdjointSerializedArray )
407- return copy (materialize (a))
404+ function adapt_structure_serialized (to, a:: AdjointSerializedArray )
405+ return adjoint (adapt_serialized (to, parent (a)))
408406end
409407
410408haschunks (a:: AdjointSerializedArray ) = Unchunked ()
@@ -452,9 +450,16 @@ function BroadcastSerializedArray(
452450end
453451Base. size (a:: BroadcastSerializedArray ) = size (a. broadcasted)
454452Base. broadcastable (a:: BroadcastSerializedArray ) = a. broadcasted
455- function Base. copy (a:: BroadcastSerializedArray )
456- # Broadcast over the materialized arrays.
457- return copy (Base. Broadcast. broadcasted (a. broadcasted. f, memory .(a. broadcasted. args)... ))
453+
454+ function adapt_structure_serialized (to, a:: BroadcastSerializedArray )
455+ return Base. Broadcast. broadcasted (
456+ a. broadcasted. f, map (adapt_serialized (to), a. broadcasted. args)...
457+ )
458+ end
459+
460+ # Special case to eagerly instantiate broadcasts.
461+ function adapt_storage_serialized (:: MemoryAdaptor , a:: BroadcastSerializedArray )
462+ return copy (a)
458463end
459464
460465function Base. copy (broadcasted:: Broadcasted{SerializedArrayStyle{N}} ) where {N}
0 commit comments