Skip to content

Commit c1eb10b

Browse files
committed
Implement
1 parent 8bf603d commit c1eb10b

File tree

5 files changed

+45
-5
lines changed

5 files changed

+45
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "2.9.2"
3+
version = "2.9.3"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import Base: *, +, -, /, <, ==, >, \, ≤, ≥, (:), @_gc_preserve_begin, @_gc_p
1818
oneto, add_sum, promote_op
1919

2020
import Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted, DefaultArrayStyle, broadcasted, combine_eltypes,
21-
instantiate
21+
instantiate, result_style, Unknown
2222

2323
import LinearAlgebra: AbstractQ, AdjOrTrans, StructuredMatrixStyle, checksquare, det, diag, dot, lmul!, logabsdet,
2424
norm1, norm2, normInf, normp, pinv, rmul!, tr, tril, triu

src/lazyconcat.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,16 @@ end
439439
# to take advantage of special implementations of the sub-components
440440
######
441441

442-
BroadcastStyle(::Type{<:Vcat{<:Any,N}}) where N = LazyArrayStyle{N}()
443-
BroadcastStyle(::Type{<:Hcat{<:Any}}) = LazyArrayStyle{2}()
442+
@inline tuple_type_broadcastlayout(::Type{I}) where {I<:Tuple} = result_style(BroadcastStyle(Base.tuple_type_head(I)), tuple_type_broadcastlayout(Base.tuple_type_tail(I)))
443+
@inline tuple_type_broadcastlayout(::Type{Tuple{}}) = Unknown()
444+
@inline tuple_type_broadcastlayout(::Type{Tuple{A}}) where {A} = BroadcastStyle(A)
445+
@inline tuple_type_broadcastlayout(::Type{Tuple{A,B}}) where {A,B} = result_style(BroadcastStyle(A), BroadcastStyle(B))
446+
@inline tuple_type_broadcastlayout(::Type{Tuple{A,B,C}}) where {A,B,C} = result_style(BroadcastStyle(A), tuple_type_broadcastlayout(Tuple{B,C}))
447+
@inline tuple_type_broadcastlayout(::Type{Tuple{A,B,C,D}}) where {A,B,C,D} = result_style(BroadcastStyle(A), tuple_type_broadcastlayout(Tuple{B,C,D}))
448+
@inline tuple_type_broadcastlayout(::Type{Tuple{A,B,C,D,E}}) where {A,B,C,D,E} = result_style(BroadcastStyle(A), tuple_type_broadcastlayout(Tuple{B,C,D,E}))
449+
450+
BroadcastStyle(::Type{<:Vcat{<:Any,N,I}}) where {N,I<:Tuple} = result_style(LazyArrayStyle{N}(), tuple_type_broadcastlayout(I)) # the <:Tuple is to avoid ambiguity
451+
BroadcastStyle(::Type{<:Hcat{<:Any,I}}) where {I<:Tuple} = result_style(LazyArrayStyle{2}(), tuple_type_broadcastlayout(I))
444452

445453
# This is if we broadcast a function on a mixed concat f.([1; [2,3]])
446454
# such that f returns a vector, e.g., f(1) == [1,2], we don't want

test/cachetests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import LazyArrays: CachedArray, CachedMatrix, CachedVector, PaddedLayout, Cached
66
CachedAbstractArray, CachedAbstractVector, CachedAbstractMatrix, AbstractCachedArray, AbstractCachedMatrix,
77
PaddedColumns, cacheddata, LazyArrayStyle, maybe_cacheddata, Accumulate, CachedArrayStyle, GenericCachedLayout,
88
AccumulateAbstractVector
9+
import Base.Broadcast: BroadcastStyle
910

1011
using ..InfiniteArrays
1112
using .InfiniteArrays: OneToInf
@@ -574,6 +575,18 @@ using Infinities
574575
@test cacheddata(F) === view(cacheddata(parent(parent(F))), 1:1)'
575576
@test cacheddata(G) === adjoint(view(cacheddata(parent(G)), 1:1, 1:1))
576577
end
578+
579+
@testset "BroadcastStyle for Vcat/Hcat of CachedArrayStyles" begin
580+
@test BroadcastStyle(typeof(Vcat(cache(1:3), cache(4:6)))) == CachedArrayStyle{1}()
581+
d = Accumulate(*, 1:10)
582+
@test BroadcastStyle(typeof(Vcat(d, d))) == CachedArrayStyle{1}()
583+
@test BroadcastStyle(typeof(Vcat(d', d'))) == CachedArrayStyle{2}()
584+
@test BroadcastStyle(typeof(Hcat(d, d))) == CachedArrayStyle{2}()
585+
@test BroadcastStyle(typeof(Vcat(d', (1:10)'))) == CachedArrayStyle{2}()
586+
@test BroadcastStyle(typeof(Vcat((1:10)', d'))) == CachedArrayStyle{2}()
587+
@test BroadcastStyle(typeof(Hcat(d, (1:10)))) == CachedArrayStyle{2}()
588+
@test BroadcastStyle(typeof(Hcat((1:10), d))) == CachedArrayStyle{2}()
589+
end
577590
end
578591

579592
end # module

test/concattests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ using LazyArrays, FillArrays, LinearAlgebra, ArrayLayouts, Test, Base64
44
using StaticArrays
55
import LazyArrays: MemoryLayout, DenseColumnMajor, materialize!, call, paddeddata,
66
MulAdd, Applied, ApplyLayout, DefaultApplyStyle, sub_materialize, resizedata!,
7-
CachedVector, ApplyLayout, arguments, BroadcastVector, LazyLayout, cacheddata
7+
CachedVector, ApplyLayout, arguments, BroadcastVector, LazyLayout, cacheddata,
8+
LazyArrayStyle, CachedArrayStyle, Accumulate
9+
import Base.Broadcast: BroadcastStyle
810

911
@testset "concat" begin
1012
@testset "Vcat" begin
@@ -712,4 +714,21 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, materialize!, call, paddeddat
712714
end
713715
end
714716

717+
@testset "BroadcastStyle" begin
718+
args = (1:10, Accumulate(*, 1:10), BroadcastVector(exp, 1:10), BroadcastMatrix(exp, rand(10, 2)), Vcat(Accumulate(*, 1:10)', (1:10)')', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
719+
for i in 1:6
720+
@test @inferred(LazyArrays.tuple_type_broadcastlayout(typeof(args[1:i]))) == Base.Broadcast.combine_styles(args[1:i]...)
721+
end
722+
@test @inferred(LazyArrays.tuple_type_broadcastlayout(typeof(args))) == Base.Broadcast.combine_styles(args...) == CachedArrayStyle{2}()
723+
@test @inferred(LazyArrays.tuple_type_broadcastlayout(Tuple{})) == Base.Broadcast.Unknown()
724+
@test BroadcastStyle(typeof(Vcat(adjoint.(args)...))) == BroadcastStyle(typeof(Vcat(transpose.(args)...))) == CachedArrayStyle{2}()
725+
@test BroadcastStyle(typeof(Hcat(args...))) == CachedArrayStyle{2}()
726+
727+
@test BroadcastStyle(typeof(Vcat((1:10)'))) == LazyArrayStyle{2}() # make sure we preserve Lazy even without lazy args
728+
@test BroadcastStyle(typeof(Hcat((1:10)'))) == LazyArrayStyle{2}()
729+
730+
@test BroadcastStyle(typeof(Vcat())) == LazyArrayStyle{1}()
731+
@test BroadcastStyle(typeof(Hcat())) == LazyArrayStyle{2}()
732+
end
733+
715734
end # module

0 commit comments

Comments
 (0)