Skip to content

Commit d9a555e

Browse files
committed
updated ChainRules adjoints
1 parent 982f075 commit d9a555e

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/if_required/chainrulescore.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1-
ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s)
1+
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where {s}
2+
function getproperty_adjoint(Δ)
3+
zero_x = zero(x)
4+
setproperty!(zero_x, s, Δ)
5+
return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent())
6+
end
7+
8+
return getproperty(x, s), getproperty_adjoint
9+
end
10+
211
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol)
312
function getproperty_adjoint(Δ)
413
zero_x = zero(x)
514
setproperty!(zero_x, s, Δ)
6-
return (ChainRulesCore.NO_FIELDS, zero_x)
15+
return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent())
716
end
817

918
return getproperty(x, s), getproperty_adjoint
1019
end
1120

12-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NO_FIELDS, ComponentArray(Δ, getaxes(x)))
21+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
1322

14-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NO_FIELDS, getdata(Δ), getaxes(Δ))
23+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), getaxes(Δ))

0 commit comments

Comments
 (0)