Skip to content

Commit 7b6bd16

Browse files
committed
2 parents a2f9f8b + 05b9f61 commit 7b6bd16

File tree

7 files changed

+46
-14
lines changed

7 files changed

+46
-14
lines changed

NEWS.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
# ComponentArrays.jl NEWS
22
Notes on new features (minor releases). For more details on bugfixes and non-feature-adding changes (patch releases), check out the [releases page](https://github.com/jonniedie/ComponentArrays.jl/releases).
33

4+
### v0.9.0
5+
- Construct `ComponentArray`s from `Dict`s!
6+
```julia
7+
julia> d = Dict(:a=>rand(3), :b=>rand(2,2))
8+
Dict{Symbol, Array{Float64, N} where N} with 2 entries:
9+
:a => [0.996693, 0.148683, 0.203083]
10+
:b => [0.68759 0.41585; 0.900591 0.377475]
11+
12+
julia> ComponentArray(d)
13+
ComponentVector{Float64}(a = [0.9966932920820444, 0.14868304847436709, 0.20308284992079573], b = [0.6875902095731583 0.415850281435181; 0.9005909643364229 0.3774747843717925])
14+
```
15+
416
### v0.8.0
517
- Generated `valkeys` function for fast iteration over `ComponentVector` subcomponents!
618
```julia

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.8.20"
4+
version = "0.9.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broadcasting.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,13 @@ function Base.similar(bc::BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}}, T::T
6969
end
7070

7171

72-
Base.Broadcast.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, x), getaxes(x))
72+
Base.Broadcast.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
7373

74+
# Need a special case here because `map` doesn't follow same rules as normal broadcasting. To be safe and avoid ambiguities,
75+
# we'll just handle the case where everything is a ComponentArray. Else it falls back to a plain Array output.
76+
function Base.map(f, xs::ComponentArray{<:Any, <:Any, <:Any, Axes}...) where Axes
77+
return ComponentArray(map(f, getdata.(xs)...), getaxes(Axes))
78+
end
7479

7580
# function Base.copy(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}) where {InnerStyle, Axes, N}
7681
# return ComponentArray{Axes}(Base.copy(BC.broadcasted(bc.f, map(getdata, bc.args)...)))

src/componentarray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ function ComponentArray(data, ax::AbstractAxis...)
5757
return LazyArray(ComponentArray(x, axs) for x in part_data)
5858
end
5959

60-
# Entry from NamedTuple or kwargs
60+
# Entry from NamedTuple, Dict, or kwargs
6161
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
6262
ComponentArray(nt::NamedTuple) = ComponentArray(make_carray_args(nt)...)
63+
ComponentArray(d::AbstractDict) = ComponentArray(NamedTuple{Tuple(keys(d))}(values(d)))
6364
ComponentArray{T}(;kwargs...) where T = ComponentArray{T}((;kwargs...))
6465
ComponentArray(;kwargs...) = ComponentArray((;kwargs...))
6566

src/if_required/chainrulescore.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ using ChainRulesCore: NO_FIELDS
66
# setproperty!(zero_x, s, Δ)
77
# return getproperty(x, s), zero_x
88
# end
9-
ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol) = ChainRulesCore.rrule(getproperty, x, Val(s))
10-
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s
9+
ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s)
10+
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol)
1111
function getproperty_adjoint(Δ)
1212
zero_x = zero(x)
1313
setproperty!(zero_x, s, Δ)
@@ -17,12 +17,8 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}
1717
return getproperty(x, s), getproperty_adjoint
1818
end
1919

20-
# ChainRulesCore.frule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x))
21-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(NO_FIELDS, ComponentArray(Δ, getaxes(x)))
22-
23-
ChainRulesCore.rrule(::typeof(getaxes), x::ComponentArray) = getaxes(x), Δ->(NO_FIELDS, ComponentArray(getdata(x), Δ))
20+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x))
2421

2522
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(NO_FIELDS, getdata(Δ), getaxes(Δ))
2623

27-
ChainRulesCore.rrule(::Type{Axis}, nt) = Axis(nt), Δ->(NO_FIELDS, ComponentArrays.indexmap(Δ))
28-
ChainRulesCore.rrule(::Type{Axis}; kwargs...) = Axis(; kwargs...), Δ->(NO_FIELDS, (; ComponentArrays.indexmap(Δ)...))
24+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(getdata(Δ), getaxes(Δ))

src/similar_convert_copy.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,21 @@
33
Base.similar(x::ComponentArray) = ComponentArray(similar(getdata(x)), getaxes(x)...)
44
Base.similar(x::ComponentArray, ::Type{T}) where T = ComponentArray(similar(getdata(x), T), getaxes(x)...)
55
Base.similar(x::ComponentArray, ::Type{T}, ax::Tuple{Vararg{Int64,N}}) where {T,N} = similar(x, T, ax...)
6-
Base.similar(x::ComponentArray, ::Type{T}, ax::Union{Integer, Base.OneTo}...) where T =
7-
similar(getdata(x), T, ax...)
6+
function Base.similar(x::ComponentArray, ::Type{T}, ax::Union{Integer, Base.OneTo}...) where T
7+
A = similar(getdata(x), T, ax...)
8+
if size(getdata(x)) == size(A)
9+
return ComponentArray(A, getaxes(x))
10+
else
11+
return A
12+
end
13+
end
14+
815
## TODO: write length method for AbstractAxis so we can do this?
916
# function Base.similar(::Type{CA}) where CA<:ComponentArray{T,N,A,Axes} where {T,N,A,Axes}
1017
# axs = getaxes(CA)
1118
# return ComponentArray(similar(A, length.(axs)...), axs...)
1219
# end
1320

14-
1521
Base.zero(x::ComponentArray) = zero.(x)
1622

1723
## FIXME: waiting on similar(::Type{<:ComponentArray})

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ end
5656
@test typeof(ComponentArray{Float32}(undef, (ax,))) == typeof(ca_Float32)
5757
@test typeof(ComponentArray{MVector{10,Float64}}(undef, (ax,))) == typeof(ca_MVector)
5858

59+
# Entry from Dict
60+
dict1 = Dict(:a=>rand(5), :b=>rand(5,5))
61+
dict2 = Dict(:a=>3, :b=>dict1)
62+
@test ComponentArray(dict1) isa ComponentArray
63+
@test ComponentArray(dict2).b isa ComponentArray
64+
5965
@test ca == ComponentVector(a=100, b=[4, 1.3], c=(a=(a=1, b=[1.0, 4.4]), b=[0.4, 2, 1, 45]))
6066
@test cmat == ComponentMatrix(a .* a', ax, ax)
6167
@test_throws DimensionMismatch ComponentVector(sq_mat, ax)
@@ -277,6 +283,12 @@ end
277283
@test getaxes(x1mat + xmat) == (getaxes(x1)[1], FlatAxis())
278284
@test getaxes(x1mat + xmat') == (FlatAxis(), getaxes(x1)[1])
279285

286+
@test map(sqrt, ca) isa ComponentArray
287+
@test map(+, ca, sqrt.(ca)) isa ComponentArray
288+
@test map(+, sqrt.(ca), Float32.(ca), ca) isa ComponentArray
289+
@test map(+, ca, getdata(ca)) isa Array
290+
@test map(+, ca, ComponentArray(v=getdata(ca))) isa Array
291+
280292
x1 .+= x2
281293
@test getdata(x1) == 2getdata(x2)
282294

0 commit comments

Comments
 (0)