|
1 | | -ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s) |
| 1 | +function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where {s} |
| 2 | + function getproperty_adjoint(Δ) |
| 3 | + zero_x = zero(x) |
| 4 | + setproperty!(zero_x, s, Δ) |
| 5 | + return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent()) |
| 6 | + end |
| 7 | + |
| 8 | + return getproperty(x, s), getproperty_adjoint |
| 9 | +end |
| 10 | + |
2 | 11 | function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol) |
3 | 12 | function getproperty_adjoint(Δ) |
4 | 13 | zero_x = zero(x) |
5 | 14 | setproperty!(zero_x, s, Δ) |
6 | | - return (ChainRulesCore.NO_FIELDS, zero_x) |
| 15 | + return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent()) |
7 | 16 | end |
8 | 17 |
|
9 | 18 | return getproperty(x, s), getproperty_adjoint |
10 | 19 | end |
11 | 20 |
|
12 | | -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NO_FIELDS, ComponentArray(Δ, getaxes(x))) |
| 21 | +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) |
13 | 22 |
|
14 | | -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NO_FIELDS, getdata(Δ), getaxes(Δ)) |
| 23 | +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), getaxes(Δ)) |
0 commit comments