11module SerializedArrays
22
3+ using Base. PermutedDimsArrays: genperm
34using ConstructionBase: constructorof
4- using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
5- using LinearAlgebra: LinearAlgebra, mul!
5+ using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
66using Serialization: deserialize, serialize
77
8- struct SerializedArray{T,N,A<: AbstractArray{T,N} ,Axes} <: AbstractDiskArray{T,N}
8+ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
9+ const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2 }
10+ const AbstractSerializedVector{T} = AbstractSerializedArray{T,1 }
11+
12+ function _copyto_write! (dst, src)
13+ writeblock! (dst, src, axes (src)... )
14+ return dst
15+ end
16+ function _copyto_read! (dst, src)
17+ readblock! (src, dst, axes (src)... )
18+ return dst
19+ end
20+
21+ function Base. copyto! (dst:: AbstractSerializedArray , src:: AbstractArray )
22+ return _copyto_write! (dst, src)
23+ end
24+ function Base. copyto! (dst:: AbstractArray , src:: AbstractSerializedArray )
25+ return _copyto_read! (dst, src)
26+ end
27+ # Fix ambiguity error.
28+ function Base. copyto! (dst:: AbstractSerializedArray , src:: AbstractSerializedArray )
29+ return copyto! (dst, copy (src))
30+ end
31+ # Fix ambiguity error.
32+ function Base. copyto! (dst:: AbstractDiskArray , src:: AbstractSerializedArray )
33+ return copyto! (dst, copy (src))
34+ end
35+ # Fix ambiguity error.
36+ function Base. copyto! (dst:: AbstractSerializedArray , src:: AbstractDiskArray )
37+ return _copyto_write! (dst, src)
38+ end
39+ # Fix ambiguity error.
40+ function Base. copyto! (dst:: PermutedDimsArray , src:: AbstractSerializedArray )
41+ return _copyto_read! (dst, src)
42+ end
43+
44+ function Base.:(== )(a1:: AbstractSerializedArray , a2:: AbstractSerializedArray )
45+ return copy (a1) == copy (a2)
46+ end
47+ function Base.:(== )(a1:: AbstractArray , a2:: AbstractSerializedArray )
48+ return a1 == copy (a2)
49+ end
50+ function Base.:(== )(a1:: AbstractSerializedArray , a2:: AbstractArray )
51+ return copy (a1) == a2
52+ end
53+
54+ # # These cause too many ambiguity errors, try bringing them back.
55+ # function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
56+ # return arrayt(a)
57+ # end
58+ # function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
59+ # return convert(arrayt, copy(a))
60+ # end
61+ # # Fixes ambiguity error.
62+ # function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
63+ # return convert(arrayt, copy(a))
64+ # end
65+
66+ struct SerializedArray{T,N,A<: AbstractArray{T,N} ,Axes} <: AbstractSerializedArray{T,N}
967 file:: String
1068 axes:: Axes
1169end
@@ -22,17 +80,26 @@ function SerializedArray(a::AbstractArray)
2280 return SerializedArray (tempname (), a)
2381end
2482
83+ function Base. convert (arrayt:: Type{<:SerializedArray} , a:: AbstractArray )
84+ return arrayt (a)
85+ end
86+
2587function Base. similar (a:: SerializedArray , elt:: Type , dims:: Tuple{Vararg{Int}} )
2688 return constructorof (arraytype (a)){elt}(undef, dims... )
2789end
2890
91+ function materialize (a:: SerializedArray )
92+ return deserialize (file (a)):: arraytype (a)
93+ end
2994function Base. copy (a:: SerializedArray )
30- arrayt = arraytype (a)
31- return convert (arrayt, deserialize (file (a))):: arrayt
95+ return materialize (a)
3296end
3397
3498Base. size (a:: SerializedArray ) = length .(axes (a))
3599
100+ to_axis (r:: AbstractUnitRange ) = r
101+ to_axis (d:: Integer ) = Base. OneTo (d)
102+
36103#
37104# DiskArrays
38105#
@@ -64,6 +131,131 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz
64131 return similar (a, output_size)
65132end
66133
134+ struct PermutedSerializedArray{T,N,P<: PermutedDimsArray{T,N} } < :
135+ AbstractSerializedArray{T,N}
136+ permuted_parent:: P
137+ end
138+ Base. parent (a:: PermutedSerializedArray ) = parent (getfield (a, :permuted_parent ))
139+
140+ perm (a:: PermutedSerializedArray ) = perm (a. permuted_parent)
141+ perm (:: PermutedDimsArray{<:Any,<:Any,p} ) where {p} = p
142+
143+ iperm (a:: PermutedSerializedArray ) = iperm (a. permuted_parent)
144+ iperm (:: PermutedDimsArray{<:Any,<:Any,<:Any,ip} ) where {ip} = ip
145+
146+ Base. axes (a:: PermutedSerializedArray ) = genperm (axes (parent (a)), perm (a))
147+ Base. size (a:: PermutedSerializedArray ) = length .(axes (a))
148+
149+ function PermutedSerializedArray (a:: AbstractArray , perm)
150+ a′ = PermutedDimsArray (a, perm)
151+ return PermutedSerializedArray {eltype(a),ndims(a),typeof(a′)} (a′)
152+ end
153+
154+ function Base. permutedims (a:: AbstractSerializedArray , perm)
155+ return PermutedSerializedArray (a, perm)
156+ end
157+
158+ function Base. similar (a:: PermutedSerializedArray , elt:: Type , dims:: Tuple{Vararg{Int}} )
159+ return similar (parent (a), elt, dims)
160+ end
161+
162+ function materialize (a:: PermutedSerializedArray )
163+ return PermutedDimsArray (copy (parent (a)), perm (a))
164+ end
165+ function Base. copy (a:: PermutedSerializedArray )
166+ return copy (materialize (a))
167+ end
168+
169+ haschunks (a:: PermutedSerializedArray ) = Unchunked ()
170+ function DiskArrays. readblock! (a:: PermutedSerializedArray , aout, i:: OrdinalRange... )
171+ ip = iperm (a)
172+ # Permute the indices
173+ inew = genperm (i, ip)
174+ # Permute the dest block and read from the true parent
175+ DiskArrays. readblock! (parent (a), PermutedDimsArray (aout, ip), inew... )
176+ return nothing
177+ end
178+ function DiskArrays. writeblock! (a:: PermutedSerializedArray , v, i:: OrdinalRange... )
179+ ip = iperm (a)
180+ inew = genperm (i, ip)
181+ # Permute the dest block and write from the true parent
182+ DiskArrays. writeblock! (parent (a), PermutedDimsArray (v, ip), inew... )
183+ return nothing
184+ end
185+
186+ struct ReshapedSerializedArray{T,N,P<: AbstractArray{T} ,Axes} <: AbstractSerializedArray{T,N}
187+ parent:: P
188+ axes:: Axes
189+ end
190+ Base. parent (a:: ReshapedSerializedArray ) = getfield (a, :parent )
191+ Base. axes (a:: ReshapedSerializedArray ) = getfield (a, :axes )
192+
193+ function ReshapedSerializedArray (
194+ a:: AbstractSerializedArray ,
195+ ax:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ,
196+ )
197+ return ReshapedSerializedArray {eltype(a),length(ax),typeof(a),typeof(ax)} (a, ax)
198+ end
199+ function ReshapedSerializedArray (
200+ a:: AbstractSerializedArray ,
201+ shape:: Tuple {
202+ Union{Integer,AbstractUnitRange{<: Integer }},
203+ Vararg{Union{Integer,AbstractUnitRange{<: Integer }}},
204+ },
205+ )
206+ return ReshapedSerializedArray (a, to_axis .(shape))
207+ end
208+
209+ Base. size (a:: ReshapedSerializedArray ) = length .(axes (a))
210+
211+ function Base. similar (a:: ReshapedSerializedArray , elt:: Type , dims:: Tuple{Vararg{Int}} )
212+ return similar (parent (a), elt, dims)
213+ end
214+
215+ function materialize (a:: ReshapedSerializedArray )
216+ return reshape (materialize (parent (a)), axes (a))
217+ end
218+ function Base. copy (a:: ReshapedSerializedArray )
219+ a′ = materialize (a)
220+ return a′ isa Base. ReshapedArray ? copy (a′) : a′
221+ end
222+
223+ # Special case for handling nested wrappers that aren't
224+ # friendly on GPU. Consider special cases of strded arrays
225+ # and handle with stride manipulations.
226+ function Base. copy (a:: ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray} )
227+ a′ = reshape (copy (parent (a)), axes (a))
228+ return a′ isa Base. ReshapedArray ? copy (a′) : a′
229+ end
230+
231+ function Base. reshape (a:: AbstractSerializedArray , dims:: Tuple{Int,Vararg{Int}} )
232+ return ReshapedSerializedArray (a, dims)
233+ end
234+
235+ DiskArrays. haschunks (a:: ReshapedSerializedArray ) = Unchunked ()
236+ function DiskArrays. readblock! (
237+ a:: ReshapedSerializedArray{<:Any,N} , aout, i:: Vararg{AbstractUnitRange,N}
238+ ) where {N}
239+ if i == axes (a)
240+ aout .= copy (a)
241+ return a
242+ end
243+ aout .= @view copy (a)[i... ]
244+ return nothing
245+ end
246+ function DiskArrays. writeblock! (
247+ a:: ReshapedSerializedArray{<:Any,N} , ain, i:: Vararg{AbstractUnitRange,N}
248+ ) where {N}
249+ if i == axes (a)
250+ serialize (file (a), ain)
251+ return a
252+ end
253+ a′ = copy (a)
254+ a′[i... ] = ain
255+ serialize (file (a), a′)
256+ return nothing
257+ end
258+
67259#
68260# Broadcast
69261#
@@ -86,7 +278,7 @@ function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N})
86278end
87279
88280struct BroadcastSerializedArray{T,N,BC<: Broadcasted{<:SerializedArrayStyle{N}} } < :
89- AbstractDiskArray {T,N}
281+ AbstractSerializedArray {T,N}
90282 broadcasted:: BC
91283end
92284function BroadcastSerializedArray (
@@ -106,15 +298,4 @@ function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
106298 return BroadcastSerializedArray (flatten (broadcasted))
107299end
108300
109- #
110- # LinearAlgebra
111- #
112-
113- function LinearAlgebra. mul! (
114- a_dest:: AbstractMatrix , a1:: SerializedArray , a2:: SerializedArray , α:: Number , β:: Number
115- )
116- mul! (a_dest, copy (a1), copy (a2), α, β)
117- return a_dest
118- end
119-
120301end
0 commit comments