Skip to content

Commit fbf6a82

Browse files
authored
fix: restructure of tracked component arrays (#269)
1 parent f5574a5 commit fbf6a82

File tree

4 files changed

+15
-2
lines changed

4 files changed

+15
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.15.16"
4+
version = "0.15.17"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

ext/ComponentArraysTrackerExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ComponentArraysTrackerExt
22

3+
using ArrayInterface: ArrayInterface
34
using ComponentArrays, Tracker
45

56
function Tracker.param(ca::ComponentArray)
@@ -34,4 +35,10 @@ end
3435
return ComponentArrays._getindex(Base.getindex, x, v)
3536
end
3637

38+
function ArrayInterface.restructure(x::ComponentVector,
39+
y::ComponentVector{T, <:TrackedArray}) where {T}
40+
getaxes(x) == getaxes(y) || error("Axes must match")
41+
return y
42+
end
43+
3744
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
23
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
34
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
45
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"

test/autodiff_tests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote
2-
using Optimisers
2+
using Optimisers, ArrayInterface
33
using Test
44

55
F(a, x) = sum(abs2, a) * x^3
@@ -127,3 +127,8 @@ end
127127
@test eltype(getdata(ps_data)) <: Float64
128128
end
129129

130+
@testset "ArrayInterface restructure TrackedArray" begin
131+
ps = ComponentArray(; a = rand(2), b = (; c = rand(2)))
132+
ps_tracked = Tracker.param(ps)
133+
@test ArrayInterface.restructure(ps, ps_tracked) isa ComponentVector{<:Any, <:Tracker.TrackedArray}
134+
end

0 commit comments

Comments
 (0)