Skip to content

Commit 8a8c7ba

Browse files
committed
Fixes some chain rules stuff. Fixes #76
1 parent 277a1c4 commit 8a8c7ba

File tree

3 files changed

+3
-14
lines changed

3 files changed

+3
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.9.2"
4+
version = "0.9.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/ComponentArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module ComponentArrays
22

33
using ArrayInterface
4-
# using ChainRulesCore
54
using LinearAlgebra
65
using Requires
76

src/if_required/chainrulescore.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
using ChainRulesCore: NO_FIELDS
2-
3-
# ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, s::Symbol) = frule((_, Δ), getproperty, x, Val(s))
4-
# function ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, ::Val{s}) where s
5-
# zero_x = zero(x)
6-
# setproperty!(zero_x, s, Δ)
7-
# return getproperty(x, s), zero_x
8-
# end
91
ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s)
102
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol)
113
function getproperty_adjoint(Δ)
@@ -17,8 +9,6 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbo
179
return getproperty(x, s), getproperty_adjoint
1810
end
1911

20-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x))
21-
22-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(NO_FIELDS, getdata(Δ), getaxes(Δ))
12+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NO_FIELDS, ComponentArray(Δ, getaxes(x)))
2313

24-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(getdata(Δ), getaxes(Δ))
14+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NO_FIELDS, getdata(Δ), getaxes(Δ))

0 commit comments

Comments
 (0)