Skip to content

Commit bf45ae4

Browse files
authored
SubArray, Adjoint, Transpose (#4)
1 parent 3631941 commit bf45ae4

File tree

3 files changed

+209
-9
lines changed

3 files changed

+209
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SerializedArrays"
22
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

src/SerializedArrays.jl

Lines changed: 161 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ using ConstructionBase: constructorof
55
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
66
using Serialization: deserialize, serialize
77

8+
#
9+
# AbstractSerializedArray
10+
#
11+
812
abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
913
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
1014
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
@@ -63,6 +67,10 @@ end
6367
# return convert(arrayt, copy(a))
6468
# end
6569

70+
#
71+
# SerializedArray
72+
#
73+
6674
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N}
6775
file::String
6876
axes::Axes
@@ -100,10 +108,7 @@ Base.size(a::SerializedArray) = length.(axes(a))
100108
to_axis(r::AbstractUnitRange) = r
101109
to_axis(d::Integer) = Base.OneTo(d)
102110

103-
#
104-
# DiskArrays
105-
#
106-
111+
# DiskArrays interface
107112
DiskArrays.haschunks(::SerializedArray) = Unchunked()
108113
function DiskArrays.readblock!(
109114
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
@@ -131,12 +136,18 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz
131136
return similar(a, output_size)
132137
end
133138

139+
#
140+
# PermutedSerializedArray
141+
#
142+
134143
struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
135144
AbstractSerializedArray{T,N}
136145
permuted_parent::P
137146
end
138147
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))
139148

149+
file(a::PermutedSerializedArray) = file(parent(a))
150+
140151
perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
141152
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
142153

@@ -172,24 +183,30 @@ function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange
172183
# Permute the indices
173184
inew = genperm(i, ip)
174185
# Permute the dest block and read from the true parent
175-
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
186+
readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
176187
return nothing
177188
end
178189
function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...)
179190
ip = iperm(a)
180191
inew = genperm(i, ip)
181192
# Permute the dest block and write from the true parent
182-
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
193+
writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
183194
return nothing
184195
end
185196

197+
#
198+
# ReshapedSerializedArray
199+
#
200+
186201
struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N}
187202
parent::P
188203
axes::Axes
189204
end
190205
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
191206
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)
192207

208+
file(a::ReshapedSerializedArray) = file(parent(a))
209+
193210
function ReshapedSerializedArray(
194211
a::AbstractSerializedArray,
195212
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
@@ -256,6 +273,141 @@ function DiskArrays.writeblock!(
256273
return nothing
257274
end
258275

276+
#
277+
# SubSerializedArray
278+
#
279+
280+
struct SubSerializedArray{T,N,P,I,L} <: AbstractSerializedArray{T,N}
281+
sub_parent::SubArray{T,N,P,I,L}
282+
end
283+
284+
file(a::SubSerializedArray) = file(parent(a))
285+
286+
# Base methods
287+
function Base.view(a::SerializedArray, i...)
288+
return SubSerializedArray(SubArray(a, Base.to_indices(a, i)))
289+
end
290+
function Base.view(a::SerializedArray, i::CartesianIndices)
291+
return SubSerializedArray(SubArray(a, Base.to_indices(a, i)))
292+
end
293+
Base.view(a::SubSerializedArray, i...) = SubSerializedArray(view(a.sub_parent, i...))
294+
Base.view(a::SubSerializedArray, i::CartesianIndices) = view(a, i.indices...)
295+
Base.size(a::SubSerializedArray) = size(a.sub_parent)
296+
Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
297+
Base.parent(a::SubSerializedArray) = parent(a.sub_parent)
298+
Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent)
299+
300+
function materialize(a::SubSerializedArray)
301+
return view(copy(parent(a)), parentindices(a)...)
302+
end
303+
function Base.copy(a::SubSerializedArray)
304+
return copy(materialize(a))
305+
end
306+
307+
DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
308+
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
309+
if i == axes(a)
310+
aout .= copy(a)
311+
end
312+
aout[i...] = copy(view(a, i...))
313+
return nothing
314+
end
315+
function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
316+
if i == axes(a)
317+
serialize(file(a), ain)
318+
return a
319+
end
320+
a_parent = copy(parent(a))
321+
pinds = parentindices(view(a.sub_parent, i...))
322+
a_parent[pinds...] = ain
323+
serialize(file(a), a_parent)
324+
return nothing
325+
end
326+
327+
#
328+
# TransposeSerializedArray
329+
#
330+
331+
struct TransposeSerializedArray{T,P<:AbstractSerializedArray{T}} <:
332+
AbstractSerializedMatrix{T}
333+
parent::P
334+
end
335+
Base.parent(a::TransposeSerializedArray) = getfield(a, :parent)
336+
337+
file(a::TransposeSerializedArray) = file(parent(a))
338+
339+
Base.axes(a::TransposeSerializedArray) = reverse(axes(parent(a)))
340+
Base.size(a::TransposeSerializedArray) = length.(axes(a))
341+
342+
function Base.transpose(a::AbstractSerializedArray)
343+
return TransposeSerializedArray(a)
344+
end
345+
Base.transpose(a::TransposeSerializedArray) = parent(a)
346+
347+
function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
348+
return similar(parent(a), elt, dims)
349+
end
350+
351+
function materialize(a::TransposeSerializedArray)
352+
return transpose(copy(parent(a)))
353+
end
354+
function Base.copy(a::TransposeSerializedArray)
355+
return copy(materialize(a))
356+
end
357+
358+
haschunks(a::TransposeSerializedArray) = Unchunked()
359+
function DiskArrays.readblock!(a::TransposeSerializedArray, aout, i::OrdinalRange...)
360+
readblock!(parent(a), transpose(aout), reverse(i)...)
361+
return nothing
362+
end
363+
function DiskArrays.writeblock!(a::TransposeSerializedArray, ain, i::OrdinalRange...)
364+
writeblock!(parent(a), transpose(aout), reverse(i)...)
365+
return nothing
366+
end
367+
368+
#
369+
# AdjointSerializedArray
370+
#
371+
372+
struct AdjointSerializedArray{T,P<:AbstractSerializedArray{T}} <:
373+
AbstractSerializedMatrix{T}
374+
parent::P
375+
end
376+
Base.parent(a::AdjointSerializedArray) = getfield(a, :parent)
377+
378+
file(a::AdjointSerializedArray) = file(parent(a))
379+
380+
Base.axes(a::AdjointSerializedArray) = reverse(axes(parent(a)))
381+
Base.size(a::AdjointSerializedArray) = length.(axes(a))
382+
383+
function Base.adjoint(a::AbstractSerializedArray)
384+
return AdjointSerializedArray(a)
385+
end
386+
Base.adjoint(a::AdjointSerializedArray) = parent(a)
387+
Base.adjoint(a::TransposeSerializedArray{<:Real}) = parent(a)
388+
Base.transpose(a::AdjointSerializedArray{<:Real}) = parent(a)
389+
390+
function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
391+
return similar(parent(a), elt, dims)
392+
end
393+
394+
function materialize(a::AdjointSerializedArray)
395+
return adjoint(copy(parent(a)))
396+
end
397+
function Base.copy(a::AdjointSerializedArray)
398+
return copy(materialize(a))
399+
end
400+
401+
haschunks(a::AdjointSerializedArray) = Unchunked()
402+
function DiskArrays.readblock!(a::AdjointSerializedArray, aout, i::OrdinalRange...)
403+
readblock!(parent(a), adjoint(aout), reverse(i)...)
404+
return nothing
405+
end
406+
function DiskArrays.writeblock!(a::AdjointSerializedArray, ain, i::OrdinalRange...)
407+
writeblock!(parent(a), adjoint(aout), reverse(i)...)
408+
return nothing
409+
end
410+
259411
#
260412
# Broadcast
261413
#
@@ -264,7 +416,9 @@ using Base.Broadcast:
264416
BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten
265417

266418
struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
267-
Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}()
419+
function Base.BroadcastStyle(arrayt::Type{<:AbstractSerializedArray})
420+
SerializedArrayStyle{ndims(arrayt)}()
421+
end
268422
function Base.BroadcastStyle(
269423
::SerializedArrayStyle{N}, ::SerializedArrayStyle{M}
270424
) where {N,M}

test/test_basics.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using GPUArraysCore: @allowscalar
22
using JLArrays: JLArray
3-
using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray
3+
using SerializedArrays:
4+
AdjointSerializedArray,
5+
PermutedSerializedArray,
6+
ReshapedSerializedArray,
7+
SerializedArray,
8+
SubSerializedArray,
9+
TransposeSerializedArray
410
using StableRNGs: StableRNG
511
using Test: @test, @testset
612
using TestExtras: @constinferred
@@ -50,6 +56,33 @@ arrayts = (Array, JLArray)
5056
@test a isa PermutedSerializedArray{elt,2}
5157
@test similar(a) isa arrayt{elt,2}
5258
@test copy(a) == permutedims(x, (2, 1))
59+
@test copy(2a) == 2permutedims(x, (2, 1))
60+
61+
rng = StableRNG(123)
62+
x = arrayt(randn(rng, elt, 4, 4))
63+
a = transpose(SerializedArray(x))
64+
@test a isa TransposeSerializedArray{elt}
65+
@test similar(a) isa arrayt{elt,2}
66+
@test copy(a) == transpose(x)
67+
@test copy(2a) == 2transpose(x)
68+
69+
rng = StableRNG(123)
70+
x = arrayt(randn(rng, elt, 4, 4))
71+
a = adjoint(SerializedArray(x))
72+
@test a isa AdjointSerializedArray{elt}
73+
@test similar(a) isa arrayt{elt,2}
74+
@test copy(a) == adjoint(x)
75+
@test copy(2a) == 2adjoint(x)
76+
77+
rng = StableRNG(123)
78+
x = arrayt(randn(rng, elt, 4, 4))
79+
a = SerializedArray(x)
80+
@test transpose(transpose(a)) === a
81+
@test adjoint(adjoint(a)) === a
82+
if isreal(a)
83+
@test adjoint(transpose(a)) === a
84+
@test transpose(adjoint(a)) === a
85+
end
5386

5487
rng = StableRNG(123)
5588
x = arrayt(randn(rng, elt, 4, 4))
@@ -96,4 +129,17 @@ arrayts = (Array, JLArray)
96129
copyto!(y, a)
97130
b = SerializedArray(y)
98131
@test b == a
132+
133+
rng = StableRNG(123)
134+
x = arrayt(randn(rng, elt, 4, 4))
135+
y = @view x[2:3, 2:3]
136+
a = SerializedArray(a)
137+
b = @view a[2:3, 2:3]
138+
@test b isa SubSerializedArray{elt,2}
139+
c = 2b
140+
@test 2y == copy(c)
141+
@allowscalar begin
142+
b[1, 1] = 2
143+
@test @constinferred(b[1, 1]) == 2
144+
end
99145
end

0 commit comments

Comments
 (0)