Skip to content

Commit 37b4b4f

Browse files
committed
ReverseDiff support. Closes #37 closes #78
1 parent 47d8ccf commit 37b4b4f

File tree

8 files changed

+292
-36
lines changed

8 files changed

+292
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.9.4"
4+
version = "0.9.5"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/ComponentArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function __init__()
3535
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl")
3636
# @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" required("zygote.jl")
3737
@require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" required("chainrulescore.jl")
38+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl")
3839
end
3940

4041

src/if_required/reversediff.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
const TrackedComponentArray{V, D, N, DA, N, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}
2+
3+
maybe_tracked_array(val::AbstractArray, der, t) = ReverseDiff.TrackedArray(val, der, t)
4+
maybe_tracked_array(val, der, t) = ReverseDiff.TrackedReal(val, der, t)
5+
6+
function Base.getindex(tca::TrackedComponentArray, inds::Union{Symbol, Val}...)
7+
val = ReverseDiff.value(tca)[inds...]
8+
der = ReverseDiff.deriv(tca)[inds...]
9+
t = ReverseDiff.tape(tca)
10+
return maybe_tracked_array(val, der, t)
11+
end
12+
13+
function Base.getproperty(tca::TrackedComponentArray, s::Symbol)
14+
if s in (:value, :deriv, :tape)
15+
return getfield(tca, s)
16+
else
17+
val = getproperty(ReverseDiff.value(tca), s)
18+
der = getproperty(ReverseDiff.deriv(tca), s)
19+
t = ReverseDiff.tape(tca)
20+
return maybe_tracked_array(val, der, t)
21+
end
22+
end

src/plot_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ see also `labels`
8686
"""
8787
label2index(x::ComponentVector, str) = label2index(labels(x), str)
8888
function label2index(labs, str)
89-
idx = findall(startswith.(labs, Regex("\\Q$str\\E(?:(\\.|\\[))"))) #str * r"(\.|\[)"))
89+
idx = findall(startswith.(labs, Regex("\\Q$str\\E(?:(\\.|\\[))")))
9090
if !isempty(idx)
9191
return idx
9292
else

0 commit comments

Comments
 (0)