Skip to content

Commit a2f9f8b

Browse files
committed
Fixed struct definition
1 parent 43a398e commit a2f9f8b

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/componentarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ julia> collect(x)
3131
2
3232
```
3333
"""
34-
struct ComponentArray{T,N,A<:AbstractArray{T,N},Axes<:Tuple{Vararg{AbstractAxis}}} <: DenseArray{T,N}
34+
struct ComponentArray{T,N,A<:AbstractArray{T,N},Axes<:Tuple{Vararg{<:AbstractAxis}}} <: DenseArray{T,N}
3535
data::A
3636
axes::Axes
3737
end

src/if_required/chainrulescore.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using ChainRulesCore: NO_FIELDS
2+
13
# ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, s::Symbol) = frule((_, Δ), getproperty, x, Val(s))
24
# function ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, ::Val{s}) where s
35
# zero_x = zero(x)
@@ -16,8 +18,11 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}
1618
end
1719

1820
# ChainRulesCore.frule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x))
19-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x))
21+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(NO_FIELDS, ComponentArray(Δ, getaxes(x)))
22+
23+
ChainRulesCore.rrule(::typeof(getaxes), x::ComponentArray) = getaxes(x), Δ->(NO_FIELDS, ComponentArray(getdata(x), Δ))
2024

21-
ChainRulesCore.rrule(::typeof(getaxes), x::ComponentArray) = getaxes(x), Δ->ComponentArray(getdata(x), Δ)
25+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(NO_FIELDS, getdata(Δ), getaxes(Δ))
2226

23-
ChainRulesCore.rrule(::typeof(ComponentArray), data, axes) = ComponentArray(data, axes), Δ->(getdata(Δ), getaxes(Δ))
27+
ChainRulesCore.rrule(::Type{Axis}, nt) = Axis(nt), Δ->(NO_FIELDS, ComponentArrays.indexmap(Δ))
28+
ChainRulesCore.rrule(::Type{Axis}; kwargs...) = Axis(; kwargs...), Δ->(NO_FIELDS, (; ComponentArrays.indexmap(Δ)...))

0 commit comments

Comments
 (0)