Skip to content

Commit 3631941

Browse files
authored
Adapt, LinearAlgebra, etc. (#3)
1 parent 91c1a05 commit 3631941

File tree

8 files changed

+372
-20
lines changed

8 files changed

+372
-20
lines changed

Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
name = "SerializedArrays"
22
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
9-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
109
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1110

11+
[weakdeps]
12+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
15+
[extensions]
16+
SerializedArraysAdaptExt = "Adapt"
17+
SerializedArraysLinearAlgebraExt = "LinearAlgebra"
18+
1219
[compat]
20+
Adapt = "4.3.0"
1321
ConstructionBase = "1.5.8"
1422
DiskArrays = "0.4.12"
1523
LinearAlgebra = "1.10"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module SerializedArraysAdaptExt
2+
3+
using Adapt: Adapt
4+
using SerializedArrays: SerializedArray
5+
6+
function Adapt.adapt_storage(arrayt::Type{<:SerializedArray}, a::AbstractArray)
7+
return convert(arrayt, a)
8+
end
9+
10+
end
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module SerializedArraysLinearAlgebraExt
2+
3+
using LinearAlgebra: LinearAlgebra, mul!
4+
using SerializedArrays: AbstractSerializedMatrix
5+
6+
function LinearAlgebra.mul!(
7+
a_dest::AbstractMatrix,
8+
a1::AbstractSerializedMatrix,
9+
a2::AbstractSerializedMatrix,
10+
α::Number,
11+
β::Number,
12+
)
13+
mul!(a_dest, copy(a1), copy(a2), α, β)
14+
return a_dest
15+
end
16+
17+
for f in [:eigen, :qr, :svd]
18+
@eval begin
19+
function LinearAlgebra.$f(a::AbstractSerializedMatrix; kwargs...)
20+
return LinearAlgebra.$f(copy(a))
21+
end
22+
end
23+
end
24+
25+
end

src/SerializedArrays.jl

Lines changed: 198 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,69 @@
11
module SerializedArrays
22

3+
using Base.PermutedDimsArrays: genperm
34
using ConstructionBase: constructorof
4-
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
5-
using LinearAlgebra: LinearAlgebra, mul!
5+
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
66
using Serialization: deserialize, serialize
77

8-
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N}
8+
abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
9+
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
10+
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
11+
12+
function _copyto_write!(dst, src)
13+
writeblock!(dst, src, axes(src)...)
14+
return dst
15+
end
16+
function _copyto_read!(dst, src)
17+
readblock!(src, dst, axes(src)...)
18+
return dst
19+
end
20+
21+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractArray)
22+
return _copyto_write!(dst, src)
23+
end
24+
function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray)
25+
return _copyto_read!(dst, src)
26+
end
27+
# Fix ambiguity error.
28+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray)
29+
return copyto!(dst, copy(src))
30+
end
31+
# Fix ambiguity error.
32+
function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray)
33+
return copyto!(dst, copy(src))
34+
end
35+
# Fix ambiguity error.
36+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray)
37+
return _copyto_write!(dst, src)
38+
end
39+
# Fix ambiguity error.
40+
function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray)
41+
return _copyto_read!(dst, src)
42+
end
43+
44+
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray)
45+
return copy(a1) == copy(a2)
46+
end
47+
function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray)
48+
return a1 == copy(a2)
49+
end
50+
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
51+
return copy(a1) == a2
52+
end
53+
54+
# # These cause too many ambiguity errors, try bringing them back.
55+
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
56+
# return arrayt(a)
57+
# end
58+
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
59+
# return convert(arrayt, copy(a))
60+
# end
61+
# # Fixes ambiguity error.
62+
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
63+
# return convert(arrayt, copy(a))
64+
# end
65+
66+
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N}
967
file::String
1068
axes::Axes
1169
end
@@ -22,17 +80,26 @@ function SerializedArray(a::AbstractArray)
2280
return SerializedArray(tempname(), a)
2381
end
2482

83+
function Base.convert(arrayt::Type{<:SerializedArray}, a::AbstractArray)
84+
return arrayt(a)
85+
end
86+
2587
function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
2688
return constructorof(arraytype(a)){elt}(undef, dims...)
2789
end
2890

91+
function materialize(a::SerializedArray)
92+
return deserialize(file(a))::arraytype(a)
93+
end
2994
function Base.copy(a::SerializedArray)
30-
arrayt = arraytype(a)
31-
return convert(arrayt, deserialize(file(a)))::arrayt
95+
return materialize(a)
3296
end
3397

3498
Base.size(a::SerializedArray) = length.(axes(a))
3599

100+
to_axis(r::AbstractUnitRange) = r
101+
to_axis(d::Integer) = Base.OneTo(d)
102+
36103
#
37104
# DiskArrays
38105
#
@@ -64,6 +131,131 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz
64131
return similar(a, output_size)
65132
end
66133

134+
struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
135+
AbstractSerializedArray{T,N}
136+
permuted_parent::P
137+
end
138+
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))
139+
140+
perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
141+
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
142+
143+
iperm(a::PermutedSerializedArray) = iperm(a.permuted_parent)
144+
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
145+
146+
Base.axes(a::PermutedSerializedArray) = genperm(axes(parent(a)), perm(a))
147+
Base.size(a::PermutedSerializedArray) = length.(axes(a))
148+
149+
function PermutedSerializedArray(a::AbstractArray, perm)
150+
a′ = PermutedDimsArray(a, perm)
151+
return PermutedSerializedArray{eltype(a),ndims(a),typeof(a′)}(a′)
152+
end
153+
154+
function Base.permutedims(a::AbstractSerializedArray, perm)
155+
return PermutedSerializedArray(a, perm)
156+
end
157+
158+
function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
159+
return similar(parent(a), elt, dims)
160+
end
161+
162+
function materialize(a::PermutedSerializedArray)
163+
return PermutedDimsArray(copy(parent(a)), perm(a))
164+
end
165+
function Base.copy(a::PermutedSerializedArray)
166+
return copy(materialize(a))
167+
end
168+
169+
haschunks(a::PermutedSerializedArray) = Unchunked()
170+
function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange...)
171+
ip = iperm(a)
172+
# Permute the indices
173+
inew = genperm(i, ip)
174+
# Permute the dest block and read from the true parent
175+
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
176+
return nothing
177+
end
178+
function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...)
179+
ip = iperm(a)
180+
inew = genperm(i, ip)
181+
# Permute the dest block and write from the true parent
182+
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
183+
return nothing
184+
end
185+
186+
struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N}
187+
parent::P
188+
axes::Axes
189+
end
190+
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
191+
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)
192+
193+
function ReshapedSerializedArray(
194+
a::AbstractSerializedArray,
195+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
196+
)
197+
return ReshapedSerializedArray{eltype(a),length(ax),typeof(a),typeof(ax)}(a, ax)
198+
end
199+
function ReshapedSerializedArray(
200+
a::AbstractSerializedArray,
201+
shape::Tuple{
202+
Union{Integer,AbstractUnitRange{<:Integer}},
203+
Vararg{Union{Integer,AbstractUnitRange{<:Integer}}},
204+
},
205+
)
206+
return ReshapedSerializedArray(a, to_axis.(shape))
207+
end
208+
209+
Base.size(a::ReshapedSerializedArray) = length.(axes(a))
210+
211+
function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
212+
return similar(parent(a), elt, dims)
213+
end
214+
215+
function materialize(a::ReshapedSerializedArray)
216+
return reshape(materialize(parent(a)), axes(a))
217+
end
218+
function Base.copy(a::ReshapedSerializedArray)
219+
a′ = materialize(a)
220+
return a′ isa Base.ReshapedArray ? copy(a′) : a′
221+
end
222+
223+
# Special case for handling nested wrappers that aren't
224+
# friendly on GPU. Consider special cases of strded arrays
225+
# and handle with stride manipulations.
226+
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
227+
a′ = reshape(copy(parent(a)), axes(a))
228+
return a′ isa Base.ReshapedArray ? copy(a′) : a′
229+
end
230+
231+
function Base.reshape(a::AbstractSerializedArray, dims::Tuple{Int,Vararg{Int}})
232+
return ReshapedSerializedArray(a, dims)
233+
end
234+
235+
DiskArrays.haschunks(a::ReshapedSerializedArray) = Unchunked()
236+
function DiskArrays.readblock!(
237+
a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
238+
) where {N}
239+
if i == axes(a)
240+
aout .= copy(a)
241+
return a
242+
end
243+
aout .= @view copy(a)[i...]
244+
return nothing
245+
end
246+
function DiskArrays.writeblock!(
247+
a::ReshapedSerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N}
248+
) where {N}
249+
if i == axes(a)
250+
serialize(file(a), ain)
251+
return a
252+
end
253+
a′ = copy(a)
254+
a′[i...] = ain
255+
serialize(file(a), a′)
256+
return nothing
257+
end
258+
67259
#
68260
# Broadcast
69261
#
@@ -86,7 +278,7 @@ function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N})
86278
end
87279

88280
struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <:
89-
AbstractDiskArray{T,N}
281+
AbstractSerializedArray{T,N}
90282
broadcasted::BC
91283
end
92284
function BroadcastSerializedArray(
@@ -106,15 +298,4 @@ function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
106298
return BroadcastSerializedArray(flatten(broadcasted))
107299
end
108300

109-
#
110-
# LinearAlgebra
111-
#
112-
113-
function LinearAlgebra.mul!(
114-
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
115-
)
116-
mul!(a_dest, copy(a1), copy(a2), α, β)
117-
return a_dest
118-
end
119-
120301
end

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
34
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
45
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
57
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
68
SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
79
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -10,9 +12,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1012
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1113

1214
[compat]
15+
Adapt = "4"
1316
Aqua = "0.8"
1417
GPUArraysCore = "0.2"
1518
JLArrays = "0.2"
19+
LinearAlgebra = "1.10"
1620
SafeTestsets = "0.1"
1721
SerializedArrays = "0.1"
1822
StableRNGs = "1"

test/test_adaptext.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Adapt: adapt
2+
using JLArrays: JLArray
3+
using SerializedArrays: SerializedArray
4+
using StableRNGs: StableRNG
5+
using Test: @test, @testset
6+
using TestExtras: @constinferred
7+
8+
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9+
arrayts = (Array, JLArray)
10+
@testset "SerializedArraysAdaptExt (eltype=$elt, arraytype=$arrayt)" for elt in elts,
11+
arrayt in arrayts
12+
13+
rng = StableRNG(123)
14+
x = arrayt(randn(rng, elt, 4, 4))
15+
y = PermutedDimsArray(x, (2, 1))
16+
a = adapt(SerializedArray, x)
17+
@test a isa SerializedArray{elt,2,arrayt{elt,2}}
18+
b = adapt(SerializedArray, y)
19+
@test b isa
20+
PermutedDimsArray{elt,2,(2, 1),(2, 1),<:SerializedArray{elt,2,<:arrayt{elt,2}}}
21+
@test parent(b) == a
22+
end

0 commit comments

Comments
 (0)