|
1 | 1 | const BC = Base.Broadcast |
2 | 2 |
|
| 3 | +Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = BC.BroadcastStyle(A) |
3 | 4 |
|
4 | | -struct CAStyle{InnerStyle<:BC.BroadcastStyle, Axes, N} <: BC.AbstractArrayStyle{N} end |
5 | | -CAStyle(::InnerStyle, ::Axes, N) where {InnerStyle, Axes} = CAStyle{InnerStyle, Axes, N}() |
6 | | -CAStyle(::InnerStyle, ::Type{<:Axes}, N) where {InnerStyle, Axes} = CAStyle{InnerStyle, Axes, N}() |
| 5 | +Base.getindex(bc::BC.Broadcasted, inds::ComponentIndex...) = bc[value.(inds)...] |
7 | 6 |
|
8 | | -function CAStyle(::InnerStyle, ax::Axes, ::Val{N}) where {InnerStyle, Axes, N} |
9 | | - return CAStyle(InnerStyle(), ax, N) |
10 | | -end |
11 | | - |
12 | | - |
13 | | -function Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, A, N, Axes} |
14 | | - return CAStyle(Base.BroadcastStyle(A), getaxes(Axes), ndims(A)) |
15 | | -end |
16 | | -function Base.BroadcastStyle(AA::Type{<:Adjoint{T, <:ComponentArray{T, N, A, Axes}}}) where {T, N, A, Axes} |
17 | | - return CAStyle(Base.BroadcastStyle(Adjoint{T,A}), getaxes(AA), ndims(AA)) |
18 | | -end |
19 | | -function Base.BroadcastStyle(AA::Type{<:Transpose{T, <:ComponentArray{T, N, A, Axes}}}) where {T, N, A, Axes} |
20 | | - return CAStyle(Base.BroadcastStyle(Transpose{T,A}), getaxes(AA), ndims(AA)) |
21 | | -end |
22 | | - |
23 | | -function Base.BroadcastStyle(::CAStyle{InnerStyle, Axes, N}, bc::BC.Broadcasted) where {InnerStyle, Axes, N} |
24 | | - return CAStyle(Base.BroadcastStyle(InnerStyle(), bc), Axes, N) |
25 | | -end |
26 | | - |
27 | | - |
28 | | -function BC.BroadcastStyle(::CAStyle{<:In1, <:Ax1, <:N1}, ::CAStyle{<:In2, <:Ax2, <:N2}) where {In1, Ax1, N1, In2, Ax2, N2} |
29 | | - ax, N = fill_flat(Ax1, Ax2, N1, N2) |
30 | | - inner_style = BC.BroadcastStyle(In1(), In2()) |
31 | | - if inner_style isa BC.Unknown |
32 | | - inner_style = BC.DefaultArrayStyle{N}() |
33 | | - end |
34 | | - return CAStyle(inner_style, ax, N) |
35 | | -end |
36 | | -function BC.BroadcastStyle(::CAStyle{In, Ax, N1}, ::Style) where Style<:BC.DefaultArrayStyle{N2} where {In, Ax, N1, N2} |
37 | | - N = max(N1, N2) |
38 | | - ax = fill_flat(Ax, max(N1, N2)) |
39 | | - inner_style = BC.BroadcastStyle(In(), Style()) |
40 | | - return CAStyle(inner_style, ax, N) |
41 | | -end |
42 | | -function BC.BroadcastStyle(CAS::CAStyle{In, Ax, N1}, ::BC.DefaultArrayStyle{0}) where {In, Ax, N1} |
43 | | - return CAS |
44 | | -end |
45 | | -function BC.BroadcastStyle(CAS::CAStyle{In, Ax, N}, ::BC.DefaultArrayStyle{N}) where {In, Ax, N} |
46 | | - return CAS |
47 | | -end |
48 | | -function BC.BroadcastStyle(::CAStyle{In, Ax, N1}, ::Style) where Style<:BC.AbstractArrayStyle{N2} where {In, Ax, N1, N2} |
49 | | - N = max(N1, N2) |
50 | | - ax = fill_flat(Ax, max(N1, N2)) |
51 | | - inner_style = BC.BroadcastStyle(In(), Style()) |
52 | | - return CAStyle(inner_style, ax, N) |
53 | | -end |
54 | | - |
55 | | - |
56 | | -Base.convert(::Type{<:BC.Broadcasted{Nothing}}, bc::BC.Broadcasted{<:CAStyle,Axes,F,Args}) where {Axes,F,Args} = getdata(bc) |
57 | | - |
58 | | -getdata(bc::BC.Broadcasted{<:CAStyle}) = BC.broadcasted(bc.f, map(getdata, bc.args)...) |
59 | | - |
60 | | - |
61 | | -function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, args...) where {InnerStyle, Axes, N} |
62 | | - return ComponentArray{Axes}(similar(BC.Broadcasted{InnerStyle}(bc.f, bc.args, bc.axes), args...)) |
63 | | -end |
64 | | -function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, T::Type) where {InnerStyle, Axes, N} |
65 | | - return ComponentArray{Axes}(similar(BC.Broadcasted{InnerStyle}(bc.f, bc.args, bc.axes), T)) |
66 | | -end |
67 | | -function Base.similar(bc::BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}}, T::Type) where {InnerStyle, Axes, N} |
68 | | - return ComponentArray{Axes}(similar(BC.Broadcasted{BC.DefaultArrayStyle{N}}(bc.f, bc.args, bc.axes), T)) |
69 | | -end |
| 7 | +# Need special case here for adjoint vectors in order to avoid type instability in axistype |
| 8 | +BC.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) = (axes(a)[1], axes(b)[2]) |
| 9 | +BC.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) = (axes(b)[2], axes(a)[1]) |
70 | 10 |
|
| 11 | +BC.axistype(a::CombinedAxis, b::AbstractUnitRange) = a |
| 12 | +BC.axistype(a::AbstractUnitRange, b::CombinedAxis) = b |
| 13 | +BC.axistype(a::CombinedAxis, b::CombinedAxis) = CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b))) |
| 14 | +BC.axistype(a::T, b::T) where {T<:CombinedAxis} = a |
71 | 15 |
|
72 | | -# BC.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x)) |
| 16 | +Base.promote_shape(a::Tuple{Vararg{CombinedAxis}}, b::NTuple{N,AbstractUnitRange}) where N = Base.promote_shape(_array_axis.(a), b) |
| 17 | +Base.promote_shape(a::NTuple{N,AbstractUnitRange}, b::Tuple{Vararg{CombinedAxis}}) where N = Base.promote_shape(a, _array_axis.(b)) |
| 18 | +Base.promote_shape(a::Tuple{Vararg{CombinedAxis}}, b::Tuple{Vararg{CombinedAxis}}) = Base.promote_shape(_array_axis.(a), _array_axis.(b)) |
| 19 | +Base.promote_shape(a::T, b::T) where {T<:Tuple{Vararg{CombinedAxis}}} = a |
73 | 20 |
|
74 | 21 | # Need a special case here because `map` doesn't follow same rules as normal broadcasting. To be safe and avoid ambiguities, |
75 | 22 | # we'll just handle the case where everything is a ComponentArray. Else it falls back to a plain Array output. |
76 | 23 | function Base.map(f, xs::ComponentArray{<:Any, <:Any, <:Any, Axes}...) where Axes |
77 | 24 | return ComponentArray(map(f, getdata.(xs)...), getaxes(Axes)) |
78 | 25 | end |
79 | 26 |
|
80 | | -# function Base.copy(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}) where {InnerStyle, Axes, N} |
81 | | -# return ComponentArray{Axes}(Base.copy(BC.broadcasted(bc.f, map(getdata, bc.args)...))) |
82 | | -# end |
83 | | -# function Base.copy(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}) where {InnerStyle, Axes, N} |
84 | | -# return ComponentArray{Axes}(Base.copy(BC.Broadcasted(InnerStyle()))) |
85 | | -# end |
86 | 27 |
|
87 | 28 | # From https://github.com/JuliaArrays/OffsetArrays.jl/blob/master/src/OffsetArrays.jl |
88 | 29 | Base.dataids(A::ComponentArray) = Base.dataids(parent(A)) |
89 | 30 | Broadcast.broadcast_unalias(dest::ComponentArray, src) = getdata(dest) === getdata(src) ? src : Broadcast.unalias(dest, src) |
90 | | - |
91 | | - |
92 | | - |
93 | | -# Helper for extruding axes |
94 | | -function fill_flat(Ax1, Ax2, N1, N2) |
95 | | - if N1<N2 |
96 | | - N = N2 |
97 | | - ax1 = fill_flat(Ax1,N) |
98 | | - ax2 = Ax2 |
99 | | - elseif N1>N2 |
100 | | - N = N1 |
101 | | - ax1 = Ax1 |
102 | | - ax2 = fill_flat(Ax2,N) |
103 | | - else |
104 | | - N = N1 |
105 | | - ax1, ax2 = Ax1, Ax2 |
106 | | - end |
107 | | - # Ax = Base.promote_typeof(getaxes(ax1), getaxes(ax2)) |
108 | | - Ax = broadcast_promote_typeof(getaxes(ax1), getaxes(ax2)) |
109 | | - return Ax, N |
110 | | -end |
111 | | -fill_flat(Ax::Type{<:VarAxes}, N) = fill_flat(getaxes(Ax), N) |> typeof |
112 | | -function fill_flat(Ax::VarAxes, N) |
113 | | - axs = Ax |
114 | | - n = length(axs) |
115 | | - if N>n |
116 | | - axs = (axs..., ntuple(x -> FlatAxis(), N-n)...) |
117 | | - end |
118 | | - return axs |
119 | | -end |
0 commit comments