1+ Base. parent (x:: ComponentArray ) = getfield (x, :data )
2+
3+ Base. size (x:: ComponentArray ) = size (getdata (x))
4+ ArrayInterface. size (A:: ComponentArray ) = ArrayInterface. size (parent (A))
5+
6+ Base. elsize (x:: Type{<:ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = Base. elsize (A)
7+
8+ Base. axes (x:: ComponentArray ) = axes (getdata (x))
9+
10+ Base. reinterpret (:: Type{T} , x:: ComponentArray , args... ) where T = ComponentArray (reinterpret (T, getdata (x), args... ), getaxes (x))
11+
12+ Base. hcat (x:: CV... ) where {CV<: ComponentVector } = ComponentArray (reduce (hcat, getdata .(x)), getaxes (x[1 ])[1 ], FlatAxis ())
13+
14+ Base. vcat (x:: ComponentVector , y:: AbstractVector ) = vcat (getdata (x), y)
15+ Base. vcat (x:: AbstractVector , y:: ComponentVector ) = vcat (x, getdata (y))
16+ function Base. vcat (x:: ComponentVector , y:: ComponentVector )
17+ if reduce ((accum, key) -> accum || (key in keys (x)), keys (y); init= false )
18+ return vcat (getdata (x), getdata (y))
19+ else
20+ data_x, data_y = getdata .((x, y))
21+ ax_x, ax_y = only .(getaxes .((x, y)))
22+ ax_y = reindex (ax_y, length (x))
23+ idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
24+ return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )))
25+ end
26+ end
27+ Base. vcat (x:: CV... ) where {CV<: AdjOrTransComponentArray } = ComponentArray (reduce (vcat, map (y-> getdata (y. parent)' , x)), getaxes (x[1 ]))
28+ Base. vcat (x:: ComponentVector... ) = reduce (vcat, x)
29+ Base. vcat (x:: ComponentVector , args... ) = vcat (getdata (x), getdata .(args)... )
30+ Base. vcat (x:: ComponentVector , args:: Vararg{AbstractVector{T}, N} ) where {T,N} = vcat (getdata (x), getdata .(args)... )
31+
32+ function Base. permutedims (x:: ComponentArray , dims)
33+ axs = getaxes (x)
34+ return ComponentArray (permutedims (getdata (x), dims), map (i-> axs[i], dims)... )
35+ end
36+
37+ # # Indexing
38+ Base. IndexStyle (:: Type{<:ComponentArray{T,N,<:A,<:Axes}} ) where {T,N,A,Axes} = IndexStyle (A)
39+
40+ # Since we aren't really using the standard approach to indexing, this will forward things to
41+ # the correct methods
42+ Base. to_indices (x:: ComponentArray , i:: Tuple ) = i
43+ Base. to_indices (x:: ComponentArray , i:: Tuple{Vararg{Union{Integer, CartesianIndex}, N}} ) where N = i
44+ Base. to_indices (x:: ComponentArray , i:: Tuple{Vararg{Int64}} ) where N = i
45+ Base. to_index (x:: ComponentArray , i) = i
46+
47+ # Get AbstractAxis index
48+ @inline Base. getindex (:: AbstractAxis , idx:: FlatIdx ) = ComponentIndex (idx)
49+ @inline Base. getindex (ax:: AbstractAxis , :: Colon ) = ComponentIndex (:, ax)
50+ @inline Base. getindex (:: AbstractAxis{IdxMap} , s:: Symbol ) where IdxMap =
51+ ComponentIndex (getproperty (IdxMap, s))
52+
53+ # Get ComponentArray index
54+ Base. @propagate_inbounds Base. getindex (x:: ComponentArray , idx:: CartesianIndex ) = getdata (x)[idx]
55+ Base. @propagate_inbounds Base. getindex (x:: ComponentArray , idx:: FlatIdx... ) = getdata (x)[idx... ]
56+ Base. @propagate_inbounds function Base. getindex (x:: ComponentArray , idx:: FlatOrColonIdx... )
57+ axs = map ((ax, i) -> getindex (ax, i). ax, getaxes (x), idx)
58+ axs = remove_nulls (axs... )
59+ return ComponentArray (getdata (x)[idx... ], axs... )
60+ end
61+ Base. @propagate_inbounds Base. getindex (x:: ComponentArray , :: Colon ) = getdata (x)[:]
62+ @inline Base. getindex (x:: ComponentArray , :: Colon... ) = x
63+ Base. @propagate_inbounds Base. getindex (x:: ComponentArray , idx... ) = getindex (x, toval .(idx)... )
64+ @inline Base. getindex (x:: ComponentArray , idx:: Val... ) = _getindex (x, idx... )
65+
66+ # Set ComponentArray index
67+ @inline Base. setindex! (x:: ComponentArray , v, idx:: FlatIdx... ) = setindex! (getdata (x), v, idx... )
68+ Base. @propagate_inbounds Base. setindex! (x:: ComponentArray , v, :: Colon ) = setindex! (getdata (x), v, :)
69+ Base. @propagate_inbounds Base. setindex! (x:: ComponentArray , v, idx... ) = setindex! (x, v, toval .(idx)... )
70+ @inline Base. setindex! (x:: ComponentArray , v, idx:: Val... ) = _setindex! (x, v, idx... )
71+
72+ # Explicitly view
73+ Base. @propagate_inbounds Base. view (x:: ComponentArray , idx:: ComponentArrays.FlatIdx... ) = view (getdata (x), idx... )
74+ Base. @propagate_inbounds Base. view (x:: ComponentArray , idx... ) = _getindex (x, toval .(idx)... )
75+
76+ # Generated get and set index methods to do all of the heavy lifting in the type domain
77+ @generated function _getindex (x:: ComponentArray , idx... )
78+ ci = getindex .(getaxes (x), getval .(idx))
79+ inds = map (i -> i. idx, ci)
80+ axs = map (i -> i. ax, ci)
81+ axs = remove_nulls (axs... )
82+ # the index must be valid after computing `ci`
83+ :(Base. @_inline_meta ; ComponentArray (Base. maybeview (getdata (x), $ inds... ), $ axs... ))
84+ end
85+
86+ @generated function _setindex! (x:: ComponentArray , v, idx... )
87+ ci = getindex .(getaxes (x), getval .(idx))
88+ inds = map (i -> i. idx, ci)
89+ # the index must be valid after computing `ci`
90+ return :(Base. @_inline_meta ; setindex! (getdata (x), v, $ inds... ))
91+ end
92+
93+ # # Linear Algebra
94+ Base. pointer (x:: ComponentArray{T,N,A,Axes} ) where {T,N,A<: DenseArray ,Axes} = pointer (getdata (x))
95+
96+ Base. unsafe_convert (:: Type{Ptr{T}} , x:: ComponentArray{T,N,A,Axes} ) where {T,N,A,Axes} = Base. unsafe_convert (Ptr{T}, getdata (x))
97+
98+ Base. strides (x:: ComponentArray ) = strides (getdata (x))
99+ ArrayInterface. strides (A:: ComponentArray ) = ArrayInterface. strides (parent (A))
100+ for f in [:device , :stride_rank , :contiguous_axis , :contiguous_batch_size , :dense_dims ]
101+ @eval ArrayInterface.$ f (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = ArrayInterface.$ f (A)
102+ end
103+
104+ Base. stride (x:: ComponentArray , k) = stride (getdata (x), k)
105+ Base. stride (x:: ComponentArray , k:: Int64 ) = stride (getdata (x), k)
106+
107+ ArrayInterface. lu_instance (jac_prototype:: ComponentArray ) = ArrayInterface. lu_instance (getdata (jac_prototype))
108+
109+ ArrayInterface. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
110+
111+
112+
113+ # While there are some cases where these were faster, it is going to be almost impossible to
114+ # to keep up with method ambiguity errors due to other array types overloading *, /, and \.
115+ # Leaving these here and commented out for now, but will delete them later.
116+
117+ # # Avoid slower fallback
118+ # for f in [:(*), :(/), :(\)]
119+ # @eval begin
120+ # # The normal stuff
121+ # Base.$f(x::ComponentArray, y::AbstractArray) = $f(getdata(x), y)
122+ # Base.$f(x::AbstractArray, y::ComponentArray) = $f(x, getdata(y))
123+ # Base.$f(x::ComponentArray, y::ComponentArray) = $f(getdata(x), getdata(y))
124+
125+ # # A bunch of special cases to avoid ambiguous method errors
126+ # Base.$f(x::ComponentArray, y::AbstractMatrix) = $f(getdata(x), y)
127+ # Base.$f(x::AbstractMatrix, y::ComponentArray) = $f(x, getdata(y))
128+
129+ # Base.$f(x::ComponentArray, y::AbstractVector) = $f(getdata(x), y)
130+ # Base.$f(x::AbstractVector, y::ComponentArray) = $f(x, getdata(y))
131+ # end
132+ # end
133+
134+ # # Adjoint/transpose special cases
135+ # for f in [:(*), :(/)]
136+ # @eval begin
137+ # Base.$f(x::Adjoint, y::ComponentArray) = $f(getdata(x), getdata(y))
138+ # Base.$f(x::Transpose, y::ComponentArray) = $f(getdata(x), getdata(y))
139+
140+ # Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
141+ # Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
142+
143+ # Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
144+ # Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
145+
146+ # Base.$f(x::Adjoint{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
147+ # Base.$f(x::Transpose{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
148+
149+ # Base.$f(x::ComponentArray, y::Adjoint{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
150+ # Base.$f(x::ComponentArray, y::Transpose{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
151+
152+ # Base.$f(x::ComponentArray, y::Adjoint{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
153+ # Base.$f(x::ComponentArray, y::Transpose{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
154+
155+ # # There seems to be a new method in Julia > v.1.4 that specializes on this
156+ # Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
157+ # Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
158+
159+ # Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
160+ # Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
161+ # end
162+ # end
0 commit comments