Skip to content

Commit 07f0fdb

Browse files
authored
simplify reinterpret array code (#43955)
Avoid one of the memcpy calls, when possible.
1 parent 3897667 commit 07f0fdb

File tree

1 file changed

+89
-83
lines changed

1 file changed

+89
-83
lines changed

base/reinterpretarray.jl

Lines changed: 89 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -362,15 +362,11 @@ end
362362

363363
@inline @propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
364364
check_readable(a)
365-
n = sizeof(S) ÷ sizeof(T)
366-
t = Ref{NTuple{n,T}}()
367365
s = Ref{S}(a.parent[ind.j])
368-
GC.@preserve t s begin
369-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
370-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
371-
_memcpy!(tptr, sptr, sizeof(S))
366+
GC.@preserve s begin
367+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
368+
return unsafe_load(tptr, ind.i)
372369
end
373-
return t[][ind.i]
374370
end
375371

376372
# Copy `n` bytes from `src` to `dst` by calling the C library's `memcpy`.
# Both pointers are passed to C as `Ptr{UInt8}` and `n` as `Csize_t`; the call
# is declared `Cvoid`, so this returns `nothing`.
@inline function _memcpy!(dst, src, n)
    return ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)
end
@@ -386,29 +382,37 @@ end
386382
else
387383
@boundscheck checkbounds(a, i1, tailinds...)
388384
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
389-
t = Ref{T}()
390-
s = Ref{S}()
391-
GC.@preserve t s begin
392-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
393-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
394-
# Optimizations that avoid branches
395-
if sizeof(T) % sizeof(S) == 0
396-
# T is bigger than S and contains an integer number of them
397-
n = sizeof(T) ÷ sizeof(S)
385+
# Optimizations that avoid branches
386+
if sizeof(T) % sizeof(S) == 0
387+
# T is bigger than S and contains an integer number of them
388+
n = sizeof(T) ÷ sizeof(S)
389+
t = Ref{T}()
390+
GC.@preserve t begin
391+
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
398392
for i = 1:n
399-
s[] = a.parent[ind_start + i, tailinds...]
400-
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
393+
s = a.parent[ind_start + i, tailinds...]
394+
unsafe_store!(sptr, s, i)
401395
end
402-
elseif sizeof(S) % sizeof(T) == 0
403-
# S is bigger than T and contains an integer number of them
404-
s[] = a.parent[ind_start + 1, tailinds...]
405-
_memcpy!(tptr, sptr + sidx, sizeof(T))
406-
else
407-
i = 1
408-
nbytes_copied = 0
409-
# This is a bit complicated to deal with partial elements
410-
# at both the start and the end. LLVM will fold as appropriate,
411-
# once it knows the data layout
396+
end
397+
return t[]
398+
elseif sizeof(S) % sizeof(T) == 0
399+
# S is bigger than T and contains an integer number of them
400+
s = Ref{S}(a.parent[ind_start + 1, tailinds...])
401+
GC.@preserve s begin
402+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
403+
return unsafe_load(tptr + sidx)
404+
end
405+
else
406+
i = 1
407+
nbytes_copied = 0
408+
# This is a bit complicated to deal with partial elements
409+
# at both the start and the end. LLVM will fold as appropriate,
410+
# once it knows the data layout
411+
s = Ref{S}()
412+
t = Ref{T}()
413+
GC.@preserve s t begin
414+
sptr = Ptr{S}(unsafe_convert(Ref{S}, s))
415+
tptr = Ptr{T}(unsafe_convert(Ref{T}, t))
412416
while nbytes_copied < sizeof(T)
413417
s[] = a.parent[ind_start + i, tailinds...]
414418
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
@@ -418,8 +422,8 @@ end
418422
i += 1
419423
end
420424
end
425+
return t[]
421426
end
422-
return t[]
423427
end
424428
end
425429

@@ -435,44 +439,39 @@ end
435439
@boundscheck checkbounds(a, i1, tailinds...)
436440
if sizeof(T) >= sizeof(S)
437441
t = Ref{T}()
438-
s = Ref{S}()
439-
GC.@preserve t s begin
440-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
441-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
442+
GC.@preserve t begin
443+
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
442444
if sizeof(T) > sizeof(S)
443445
# Extra dimension in the parent array
444446
n = sizeof(T) ÷ sizeof(S)
445447
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
446448
offset = n * (i1 - firstindex(a))
447449
for i = 1:n
448-
s[] = a.parent[i + offset]
449-
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
450+
s = a.parent[i + offset]
451+
unsafe_store!(sptr, s, i)
450452
end
451453
else
452454
for i = 1:n
453-
s[] = a.parent[i, i1, tailinds...]
454-
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
455+
s = a.parent[i, i1, tailinds...]
456+
unsafe_store!(sptr, s, i)
455457
end
456458
end
457459
else
458460
# No extra dimension
459-
s[] = a.parent[i1, tailinds...]
460-
_memcpy!(tptr, sptr, sizeof(S))
461+
s = a.parent[i1, tailinds...]
462+
unsafe_store!(sptr, s)
461463
end
462464
end
463465
return t[]
464466
end
465467
# S is bigger than T and contains an integer number of them
466-
n = sizeof(S) ÷ sizeof(T)
467-
t = Ref{NTuple{n,T}}()
468+
# n = sizeof(S) ÷ sizeof(T)
468469
s = Ref{S}()
469-
GC.@preserve t s begin
470-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
471-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
470+
GC.@preserve s begin
471+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
472472
s[] = a.parent[tailinds...]
473-
_memcpy!(tptr, sptr, sizeof(S))
473+
return unsafe_load(tptr, i1)
474474
end
475-
return t[][i1]
476475
end
477476

478477
@inline @propagate_inbounds function setindex!(a::NonReshapedReinterpretArray{T,0,S}, v) where {T,S}
@@ -502,12 +501,10 @@ end
502501
@inline @propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
503502
check_writable(a)
504503
v = convert(T, v)::T
505-
t = Ref{T}(v)
506504
s = Ref{S}(a.parent[ind.j])
507-
GC.@preserve t s begin
508-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
509-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
510-
_memcpy!(sptr + (ind.i-1)*sizeof(T), tptr, sizeof(T))
505+
GC.@preserve s begin
506+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
507+
unsafe_store!(tptr, v, ind.i)
511508
end
512509
a.parent[ind.j] = s[]
513510
return a
@@ -526,25 +523,32 @@ end
526523
else
527524
@boundscheck checkbounds(a, i1, tailinds...)
528525
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
529-
t = Ref{T}(v)
530-
s = Ref{S}()
531-
GC.@preserve t s begin
532-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
533-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
534-
# Optimizations that avoid branches
535-
if sizeof(T) % sizeof(S) == 0
536-
# T is bigger than S and contains an integer number of them
526+
# Optimizations that avoid branches
527+
if sizeof(T) % sizeof(S) == 0
528+
# T is bigger than S and contains an integer number of them
529+
t = Ref{T}(v)
530+
GC.@preserve t begin
531+
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
537532
n = sizeof(T) ÷ sizeof(S)
538-
for i = 0:n-1
539-
_memcpy!(sptr, tptr + i*sizeof(S), sizeof(S))
540-
a.parent[ind_start + i + 1, tailinds...] = s[]
533+
for i = 1:n
534+
s = unsafe_load(sptr, i)
535+
a.parent[ind_start + i, tailinds...] = s
541536
end
542-
elseif sizeof(S) % sizeof(T) == 0
543-
# S is bigger than T and contains an integer number of them
544-
s[] = a.parent[ind_start + 1, tailinds...]
545-
_memcpy!(sptr + sidx, tptr, sizeof(T))
537+
end
538+
elseif sizeof(S) % sizeof(T) == 0
539+
# S is bigger than T and contains an integer number of them
540+
s = Ref{S}(a.parent[ind_start + 1, tailinds...])
541+
GC.@preserve s begin
542+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
543+
unsafe_store!(tptr + sidx, v)
546544
a.parent[ind_start + 1, tailinds...] = s[]
547-
else
545+
end
546+
else
547+
t = Ref{T}(v)
548+
s = Ref{S}()
549+
GC.@preserve t s begin
550+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
551+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
548552
nbytes_copied = 0
549553
i = 1
550554
# Deal with any partial elements at the start. We'll have to copy in the
@@ -591,36 +595,38 @@ end
591595
end
592596
end
593597
@boundscheck checkbounds(a, i1, tailinds...)
594-
t = Ref{T}(v)
595-
s = Ref{S}()
596-
GC.@preserve t s begin
597-
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
598-
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
599-
if sizeof(T) >= sizeof(S)
598+
if sizeof(T) >= sizeof(S)
599+
t = Ref{T}(v)
600+
GC.@preserve t begin
601+
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
600602
if sizeof(T) > sizeof(S)
601603
# Extra dimension in the parent array
602604
n = sizeof(T) ÷ sizeof(S)
603605
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
604606
offset = n * (i1 - firstindex(a))
605607
for i = 1:n
606-
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
607-
a.parent[i + offset] = s[]
608+
s = unsafe_load(sptr, i)
609+
a.parent[i + offset] = s
608610
end
609611
else
610612
for i = 1:n
611-
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
612-
a.parent[i, i1, tailinds...] = s[]
613+
s = unsafe_load(sptr, i)
614+
a.parent[i, i1, tailinds...] = s
613615
end
614616
end
615-
else
617+
else # sizeof(T) == sizeof(S)
616618
# No extra dimension
617-
_memcpy!(sptr, tptr, sizeof(S))
618-
a.parent[i1, tailinds...] = s[]
619+
s = unsafe_load(sptr)
620+
a.parent[i1, tailinds...] = s
619621
end
620-
else
621-
# S is bigger than T and contains an integer number of them
622+
end
623+
else
624+
# S is bigger than T and contains an integer number of them
625+
s = Ref{S}()
626+
GC.@preserve s begin
627+
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
622628
s[] = a.parent[tailinds...]
623-
_memcpy!(sptr + (i1-1)*sizeof(T), tptr, sizeof(T))
629+
unsafe_store!(tptr, v, i1)
624630
a.parent[tailinds...] = s[]
625631
end
626632
end

0 commit comments

Comments
 (0)