Skip to content

Commit 5807681

Browse files
committed
Broadcasting performance improvements
1 parent 85254df commit 5807681

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ research/
1010
/examples/wip
1111
/examples/ignore
1212
/examples/.ipynb_checkpoints
13-
.vscode/settings.json
13+
.vscode
1414
/issue_debug
1515
Manifest.toml

src/array_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ Base.@propagate_inbounds function Base.getindex(x::ComponentArray, idx::FlatOrCo
5959
return ComponentArray(getdata(x)[idx...], axs...)
6060
end
6161
Base.@propagate_inbounds Base.getindex(x::ComponentArray, ::Colon) = getdata(x)[:]
62-
@inline Base.getindex(x::ComponentArray, ::Colon...) = x
63-
Base.@propagate_inbounds Base.getindex(x::ComponentArray, idx...) = getindex(x, toval.(idx)...)
62+
Base.@propagate_inbounds Base.getindex(x::ComponentArray, ::Colon...) = x
63+
@inline Base.getindex(x::ComponentArray, idx...) = getindex(x, toval.(idx)...)
6464
@inline Base.getindex(x::ComponentArray, idx::Val...) = _getindex(x, idx...)
6565

6666
# Set ComponentArray index
67-
@inline Base.setindex!(x::ComponentArray, v, idx::FlatIdx...) = setindex!(getdata(x), v, idx...)
67+
Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, idx::FlatOrColonIdx...) = setindex!(getdata(x), v, idx...)
6868
Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, ::Colon) = setindex!(getdata(x), v, :)
69-
Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, idx...) = setindex!(x, v, toval.(idx)...)
69+
@inline Base.setindex!(x::ComponentArray, v, idx...) = setindex!(x, v, toval.(idx)...)
7070
@inline Base.setindex!(x::ComponentArray, v, idx::Val...) = _setindex!(x, v, idx...)
7171

7272
# Explicitly view

src/broadcasting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function Base.similar(bc::BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}}, T::T
6969
end
7070

7171

72-
BC.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
72+
# BC.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
7373

7474
# Need a special case here because `map` doesn't follow same rules as normal broadcasting. To be safe and avoid ambiguities,
7575
# we'll just handle the case where everything is a ComponentArray. Else it falls back to a plain Array output.

src/similar_convert_copy.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
# Similar
33
Base.similar(x::ComponentArray) = ComponentArray(similar(getdata(x)), getaxes(x)...)
44
Base.similar(x::ComponentArray, ::Type{T}) where T = ComponentArray(similar(getdata(x), T), getaxes(x)...)
5-
Base.similar(x::ComponentArray, ::Type{T}, ax::Tuple{Vararg{Int64,N}}) where {T,N} = similar(x, T, ax...)
6-
function Base.similar(x::ComponentArray, ::Type{T}, ax::Union{Integer, Base.OneTo}...) where T
7-
A = similar(getdata(x), T, ax...)
8-
if size(getdata(x)) == size(A)
9-
return ComponentArray(A, getaxes(x))
10-
else
11-
return A
12-
end
5+
# Base.similar(x::ComponentArray, ::Type{T}, ax::Tuple{Vararg{Int64,N}}) where {T,N} = similar(x, T, ax...)
6+
# function Base.similar(x::ComponentArray, ::Type{T}, ax::Union{Integer, Base.OneTo}...) where T
7+
# A = similar(getdata(x), T, ax...)
8+
# if size(getdata(x)) == size(A)
9+
# return ComponentArray(A, getaxes(x))
10+
# else
11+
# return A
12+
# end
13+
# end
14+
function Base.similar(x::ComponentArray{T1,N,A,Ax}, ::Type{T}, dims::NTuple{N,Int}) where {T,T1,N,A,Ax}
15+
arr = similar(getdata(x), T, dims)
16+
return ComponentArray(arr, getaxes(x))
17+
end
18+
function Base.similar(x::ComponentArray{T1,N1,A,Ax}, ::Type{T}, dims::NTuple{N2,Int}) where {T,T1,N1,N2,A,Ax}
19+
return similar(getdata(x), T, dims)
1320
end
1421

1522
## TODO: write length method for AbstractAxis so we can do this?

0 commit comments

Comments
 (0)