Skip to content

Commit a8f1aad

Browse files
committed
Fixed ReverseDiff stuff
1 parent 1efe316 commit a8f1aad

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.10.5"
4+
version = "0.10.6"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/if_required/reversediff.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ const TrackedComponentArray{V, D, N, DA, N, A, Ax} = ReverseDiff.TrackedArray{V,
33
maybe_tracked_array(val::AbstractArray, der, t) = ReverseDiff.TrackedArray(val, der, t)
44
maybe_tracked_array(val, der, t) = ReverseDiff.TrackedReal(val, der, t)
55

6-
function Base.getindex(tca::TrackedComponentArray, inds::Union{Symbol, Val}...)
7-
val = ReverseDiff.value(tca)[inds...]
8-
der = ReverseDiff.deriv(tca)[inds...]
9-
t = ReverseDiff.tape(tca)
10-
return maybe_tracked_array(val, der, t)
6+
for f in [:getindex, :view]
7+
@eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol, Val}...)
8+
val = $f(ReverseDiff.value(tca), inds...)
9+
der = Base.maybeview(ReverseDiff.deriv(tca), inds...)
10+
t = ReverseDiff.tape(tca)
11+
return maybe_tracked_array(val, der, t)
12+
end
1113
end
1214

1315
function Base.getproperty(tca::TrackedComponentArray, s::Symbol)

test/autodiff_tests.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,27 @@ using Test
88
F(x, θ, deg) = (θ[1] - x[1])^deg + θ[2] * (x[2] - x[1]^deg)^deg
99
F_idx_val(ca) = F(ca[Val(:x)], ca[Val()], ca[Val(:deg)])
1010
F_idx_sym(ca) = F(ca[:x], ca[], ca[:deg])
11+
F_view_val(ca) = F(@view(ca[Val(:x)]), @view(ca[Val()]), ca[Val(:deg)])
12+
F_view_sym(ca) = F(@view(ca[:x]), @view(ca[]), ca[:deg])
1113
F_prop(ca) = F(ca.x, ca.θ, ca.deg)
1214

1315
ca = ComponentArray(x = [1, 2], θ = [1.0, 100.0], deg = 2)
1416
truth = [-400, 200]
1517

16-
@testset "$(nameof(F_))" for F_ in (F_idx_val, F_idx_sym, F_prop)
18+
@testset "$(nameof(F_))" for F_ in (F_idx_val, F_idx_sym, F_view_val, F_view_sym, F_prop)
1719
finite = FiniteDiff.finite_difference_gradient(ca -> F_(ca), ca).x
1820
@test finite truth
1921

2022
forward = ForwardDiff.gradient(ca -> F_(ca), ca).x
2123
@test forward truth
2224

2325
reverse = ReverseDiff.gradient(ca -> F_(ca), ca).x
24-
if F_ in (F_idx_val, F_idx_sym)
25-
@test_broken reverse truth
26+
@test reverse truth
27+
28+
zygote_full = Zygote.gradient(ca -> F_(ca), ca)[1]
29+
if F_ == F_prop && VERSION < v"1.3"
30+
@test_broken zygote_full.x truth
2631
else
27-
@test reverse truth
32+
@test zygote_full.x truth
2833
end
29-
30-
zygote = Zygote.gradient(ca -> F_(ca), ca)[1].x
31-
@test zygote truth
3234
end

0 commit comments

Comments
 (0)