@@ -10,63 +10,89 @@ function CAStyle(::InnerStyle, ax::Axes, ::Val{N}) where {InnerStyle, Axes, N}
1010end
1111
1212
13- function Base. BroadcastStyle (:: Type{<:ComponentArray{T, N, A, Axes}} ) where {T, A, N, Axes}
14- return CAStyle (Base. BroadcastStyle (A), getaxes (Axes), ndims (A))
15- end
16- function Base. BroadcastStyle (AA:: Type{<:Adjoint{T, <:ComponentArray{T, N, A, Axes}}} ) where {T, N, A, Axes}
17- return CAStyle (Base. BroadcastStyle (Adjoint{T,A}), getaxes (AA), ndims (AA))
18- end
19- function Base. BroadcastStyle (AA:: Type{<:Transpose{T, <:ComponentArray{T, N, A, Axes}}} ) where {T, N, A, Axes}
20- return CAStyle (Base. BroadcastStyle (Transpose{T,A}), getaxes (AA), ndims (AA))
21- end
22-
23- function Base. BroadcastStyle (:: CAStyle{InnerStyle, Axes, N} , bc:: BC.Broadcasted ) where {InnerStyle, Axes, N}
24- return CAStyle (Base. BroadcastStyle (InnerStyle (), bc), Axes, N)
25- end
13+ # function Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, A, N, Axes}
14+ # return CAStyle(Base.BroadcastStyle(A), getaxes(Axes), ndims(A))
15+ # end
16+ # function Base.BroadcastStyle(AA::Type{<:Adjoint{T, <:ComponentArray{T, N, A, Axes}}}) where {T, N, A, Axes}
17+ # return CAStyle(Base.BroadcastStyle(Adjoint{T,A}), getaxes(AA), ndims(AA))
18+ # end
19+ # function Base.BroadcastStyle(AA::Type{<:Transpose{T, <:ComponentArray{T, N, A, Axes}}}) where {T, N, A, Axes}
20+ # return CAStyle(Base.BroadcastStyle(Transpose{T,A}), getaxes(AA), ndims(AA))
21+ # end
2622
23+ # function Base.BroadcastStyle(::CAStyle{InnerStyle, Axes, N}, bc::BC.Broadcasted) where {InnerStyle, Axes, N}
24+ # return CAStyle(Base.BroadcastStyle(InnerStyle(), bc), Axes, N)
25+ # end
2726
28- function BC. BroadcastStyle (:: CAStyle{<:In1, <:Ax1, <:N1} , :: CAStyle{<:In2, <:Ax2, <:N2} ) where {In1, Ax1, N1, In2, Ax2, N2}
29- ax, N = fill_flat (Ax1, Ax2, N1, N2)
30- inner_style = BC. BroadcastStyle (In1 (), In2 ())
31- if inner_style isa BC. Unknown
32- inner_style = BC. DefaultArrayStyle {N} ()
33- end
34- return CAStyle (inner_style, ax, N)
35- end
36- function BC. BroadcastStyle (:: CAStyle{In, Ax, N1} , :: Style ) where Style<: BC.DefaultArrayStyle{N2} where {In, Ax, N1, N2}
37- N = max (N1, N2)
38- ax = fill_flat (Ax, max (N1, N2))
39- inner_style = BC. BroadcastStyle (In (), Style ())
40- return CAStyle (inner_style, ax, N)
41- end
42- function BC. BroadcastStyle (CAS:: CAStyle{In, Ax, N1} , :: BC.DefaultArrayStyle{0} ) where {In, Ax, N1}
43- return CAS
44- end
45- function BC. BroadcastStyle (CAS:: CAStyle{In, Ax, N} , :: BC.DefaultArrayStyle{N} ) where {In, Ax, N}
46- return CAS
47- end
48- function BC. BroadcastStyle (:: CAStyle{In, Ax, N1} , :: Style ) where Style<: BC.AbstractArrayStyle{N2} where {In, Ax, N1, N2}
49- N = max (N1, N2)
50- ax = fill_flat (Ax, max (N1, N2))
51- inner_style = BC. BroadcastStyle (In (), Style ())
52- return CAStyle (inner_style, ax, N)
53- end
27+ Base. BroadcastStyle (:: Type{<:ComponentArray{T, N, A, Axes}} ) where {T, N, A, Axes} = BC. DefaultArrayStyle {N} ()
28+ # Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = BC.BroadcastStyle(A)
29+
30+ Base. getindex (bc:: BC.Broadcasted , inds:: ComponentIndex... ) = bc[value .(inds)... ]
31+
32+ # Need special case here for adjoint vectors in order to avoid type instability in axistype
33+ BC. combine_axes (a:: ComponentArray , b:: AdjOrTransComponentVector ) = (axes (a)[1 ], axes (b)[2 ])
34+ BC. combine_axes (a:: AdjOrTransComponentVector , b:: ComponentArray ) = (axes (b)[2 ], axes (a)[1 ])
35+
36+ # BC.axistype(a::CombinedAxis, b::AbstractUnitRange) = Base.Broadcast.axistype(_array_axis(a), b)
37+ # BC.axistype(a::AbstractUnitRange, b::CombinedAxis) = Base.Broadcast.axistype(a, _array_axis(b))
38+ # BC.axistype(a::CombinedAxis, b::CombinedAxis) = Base.Broadcast.axistype(_array_axis(a), _array_axis(b))
39+ BC. axistype (a:: CombinedAxis , b:: AbstractUnitRange ) = a
40+ BC. axistype (a:: AbstractUnitRange , b:: CombinedAxis ) = b
41+ BC. axistype (a:: CombinedAxis , b:: CombinedAxis ) = CombinedAxis (FlatAxis (), Base. Broadcast. axistype (_array_axis (a), _array_axis (b)))
42+ BC. axistype (a:: T , b:: T ) where {T<: CombinedAxis } = a
43+
44+ Base. promote_shape (a:: Tuple{Vararg{CombinedAxis}} , b:: NTuple{N,AbstractUnitRange} ) where N = Base. promote_shape (_array_axis .(a), b)
45+ Base. promote_shape (a:: NTuple{N,AbstractUnitRange} , b:: Tuple{Vararg{CombinedAxis}} ) where N = Base. promote_shape (a, _array_axis .(b))
46+ Base. promote_shape (a:: Tuple{Vararg{CombinedAxis}} , b:: Tuple{Vararg{CombinedAxis}} ) = Base. promote_shape (_array_axis .(a), _array_axis .(b))
47+ # Base.promote_shape(a::Tuple{Vararg{Union{AbstractUnitRange, CombinedAxis}}}, b::Tuple{Vararg{Union{AbstractUnitRange, CombinedAxis}}}) = promote_shape(_array_axis.(a), _array_axis.(b))
48+ Base. promote_shape (a:: T , b:: T ) where {T<: Tuple{Vararg{CombinedAxis}} } = a
49+
50+ # # Hack to make things like Dual.(ComponentArray(a=1,b=1), [1,1]) work
51+ # BC.broadcasted(f::Type, arg1::ComponentArray, args...) = ComponentArray(f.(getdata(arg1), getdata.(args)...), getaxes(arg1))
52+ # BC.broadcasted(f::Type, arg1, arg2::ComponentArray, args...) = ComponentArray(f.(arg1, getdata(arg2), getdata.(args)...), getaxes(arg2))
53+
54+ # function BC.BroadcastStyle(::CAStyle{<:In1, <:Ax1, <:N1}, ::CAStyle{<:In2, <:Ax2, <:N2}) where {In1, Ax1, N1, In2, Ax2, N2}
55+ # ax, N = fill_flat(Ax1, Ax2, N1, N2)
56+ # inner_style = BC.BroadcastStyle(In1(), In2())
57+ # if inner_style isa BC.Unknown
58+ # inner_style = BC.DefaultArrayStyle{N}()
59+ # end
60+ # return CAStyle(inner_style, ax, N)
61+ # end
62+ # function BC.BroadcastStyle(::CAStyle{In, Ax, N1}, ::Style) where Style<:BC.DefaultArrayStyle{N2} where {In, Ax, N1, N2}
63+ # N = max(N1, N2)
64+ # ax = fill_flat(Ax, max(N1, N2))
65+ # inner_style = BC.BroadcastStyle(In(), Style())
66+ # return CAStyle(inner_style, ax, N)
67+ # end
68+ # function BC.BroadcastStyle(CAS::CAStyle{In, Ax, N1}, ::BC.DefaultArrayStyle{0}) where {In, Ax, N1}
69+ # return CAS
70+ # end
71+ # function BC.BroadcastStyle(CAS::CAStyle{In, Ax, N}, ::BC.DefaultArrayStyle{N}) where {In, Ax, N}
72+ # return CAS
73+ # end
74+ # function BC.BroadcastStyle(::CAStyle{In, Ax, N1}, ::Style) where Style<:BC.AbstractArrayStyle{N2} where {In, Ax, N1, N2}
75+ # N = max(N1, N2)
76+ # ax = fill_flat(Ax, max(N1, N2))
77+ # inner_style = BC.BroadcastStyle(In(), Style())
78+ # return CAStyle(inner_style, ax, N)
79+ # end
5480
5581
56- Base. convert (:: Type{<:BC.Broadcasted{Nothing}} , bc:: BC.Broadcasted{<:CAStyle,Axes,F,Args} ) where {Axes,F,Args} = getdata (bc)
82+ # Base.convert(::Type{<:BC.Broadcasted{Nothing}}, bc::BC.Broadcasted{<:CAStyle,Axes,F,Args}) where {Axes,F,Args} = getdata(bc)
5783
58- getdata (bc:: BC.Broadcasted{<:CAStyle} ) = BC. broadcasted (bc. f, map (getdata, bc. args)... )
84+ # getdata(bc::BC.Broadcasted{<:CAStyle}) = BC.broadcasted(bc.f, map(getdata, bc.args)...)
5985
6086
61- function Base. similar (bc:: BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}} , args... ) where {InnerStyle, Axes, N}
62- return ComponentArray {Axes} ( similar (BC. Broadcasted {InnerStyle} (bc. f, bc. args, bc. axes), args... ) )
63- end
64- function Base. similar (bc:: BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}} , T:: Type ) where {InnerStyle, Axes, N}
65- return ComponentArray {Axes} ( similar (BC. Broadcasted {InnerStyle} (bc. f, bc. args, bc. axes), T) )
66- end
67- function Base. similar (bc:: BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}} , T:: Type ) where {InnerStyle, Axes, N}
68- return ComponentArray {Axes} ( similar (BC. Broadcasted {BC.DefaultArrayStyle{N}} (bc. f, bc. args, bc. axes), T) )
69- end
87+ # function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, args...) where {InnerStyle, Axes, N}
88+ # return similar(BC.Broadcasted{InnerStyle}(bc.f, bc.args, bc.axes), args...)
89+ # end
90+ # function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, T::Type) where {InnerStyle, Axes, N}
91+ # return similar(BC.Broadcasted{InnerStyle}(bc.f, bc.args, bc.axes), T)
92+ # end
93+ # function Base.similar(bc::BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}}, T::Type) where {InnerStyle, Axes, N}
94+ # return similar(BC.Broadcasted{BC.DefaultArrayStyle{N}}(bc.f, bc.args, bc.axes), T)
95+ # end
7096
7197
7298# BC.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
@@ -77,6 +103,7 @@ function Base.map(f, xs::ComponentArray{<:Any, <:Any, <:Any, Axes}...) where Axe
77103 return ComponentArray (map (f, getdata .(xs)... ), getaxes (Axes))
78104end
79105
106+
80107# function Base.copy(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}) where {InnerStyle, Axes, N}
81108# return ComponentArray{Axes}(Base.copy(BC.broadcasted(bc.f, map(getdata, bc.args)...)))
82109# end
0 commit comments