Skip to content

Commit d9b2905

Browse files
committed
Add BroadcastStyle support for BroadcastArray
1 parent d54e5a0 commit d9b2905

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/lazybroadcasting.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ end
5252
const BroadcastVector{T,F,Args} = BroadcastArray{T,1,F,Args}
5353
const BroadcastMatrix{T,F,Args} = BroadcastArray{T,2,F,Args}
5454

55+
BroadcastStyle(::Type{<:BroadcastArray{<:Any,N,<:Any,Args}}) where {N,Args<:Tuple} = result_style(LazyArrayStyle{N}(), tuple_type_broadcastlayout(Args))
56+
5557
LazyArray(bc::Broadcasted) = BroadcastArray(bc)
5658

5759
BroadcastArray{T,N,F,Args}(bc::Broadcasted) where {T,N,F,Args} = BroadcastArray{T,N,F,Args}(bc.f,bc.args)

test/broadcasttests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module BroadcastTests
22

33
using LazyArrays, ArrayLayouts, LinearAlgebra, FillArrays, Base64, Test
44
using StaticArrays, Tracker
5-
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle, sub_materialize, simplifiable
5+
import LazyArrays: BroadcastLayout, CachedArrayStyle, arguments, LazyArrayStyle, sub_materialize, simplifiable
66
import Base: broadcasted
7+
import Base.Broadcast: BroadcastStyle
78

89
using ..InfiniteArrays
910
using Infinities
@@ -478,6 +479,13 @@ using Infinities
478479
@test_throws "MethodError: no method matching _vec_mul_arguments" LazyArrays._vec_mul_arguments(2, [])
479480
end
480481
end
482+
483+
@testset "BroadcastStyle" begin
484+
@test BroadcastStyle(typeof(BroadcastVector(exp, 1:10))) == LazyArrayStyle{1}()
485+
@test BroadcastStyle(typeof(BroadcastVector(+, 1:10, cache(1:10)))) == CachedArrayStyle{1}()
486+
@test BroadcastStyle(typeof(BroadcastMatrix(*, Accumulate(*, 1:10)', rand(10)'))) == CachedArrayStyle{2}()
487+
@test BroadcastStyle(typeof(BroadcastMatrix(*, rand(5, 5)', LazyArrays.CachedArray(rand(5, 5))))) == CachedArrayStyle{2}()
488+
end
481489
end
482490

483491
end #module

0 commit comments

Comments
 (0)