11function ChainRulesCore. rrule (:: typeof (getindex),VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
22 function AbstractVectorOfArray_getindex_adjoint (Δ)
33 Δ′ = [ (i == j ? Δ : zero (x)) for (x,j) in zip (VA. u, 1 : length (VA))]
4- (NoTangent (),Δ′ ,NoTangent ())
4+ (NoTangent (),VectorOfArray (Δ′) ,NoTangent ())
55 end
66 VA[i],AbstractVectorOfArray_getindex_adjoint
77end
@@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi
1010 function AbstractVectorOfArray_getindex_adjoint (Δ)
1111 Δ′ = zero (VA)
1212 Δ′[indices... ] = Δ
13- (NoTangent (), Δ′, indices[ 1 ], map (_ -> NoTangent (), indices[ 2 : end ] )... )
13+ (NoTangent (), VectorOfArray (Δ′), map (_ -> NoTangent (), indices)... )
1414 end
1515 VA[indices... ],AbstractVectorOfArray_getindex_adjoint
1616end
@@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
1919 function ArrayPartition_adjoint (_y)
2020 y = Array (_y)
2121 starts = vcat (0 ,cumsum (reduce (vcat,length .(x))))
22- NoTangent (), ntuple (i -> reshape (y[starts[i]+ 1 : starts[i+ 1 ]], size (x[i])), length (x)), NoTangent ()
22+ NoTangent (), ArrayPartition ( ntuple (i -> reshape (y[starts[i]+ 1 : starts[i+ 1 ]], size (x[i]) )), length (x)), NoTangent ()
2323 end
2424
2525 ArrayPartition (x, Val{copy_x}), ArrayPartition_adjoint
@@ -43,23 +43,33 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343 A. x,literal_ArrayPartition_x_adjoint
4444end
4545
46+ # Define a new species of projection operator for this type:
47+ ChainRulesCore. ProjectTo (x:: VectorOfArray ) = ChainRulesCore. ProjectTo {VectorOfArray} ()
48+
49+ # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
50+ # (::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
51+ # Gradient from broadcasting will be another AbstractArray
52+ # (::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
53+
54+ # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
55+ # definition first, and finds its own before finding those.
56+
4657ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
4758 function AbstractVectorOfArray_getindex_adjoint (Δ)
48- Δ′ = [ (i == j ? Δ : zero (x )) for (x,j) in zip (VA. u, 1 : length (VA))]
49- (Δ′ ,nothing )
59+ Δ′ = [(i == j ? Δ : Fill ( zero (eltype (x)), size (x) )) for (x,j) in zip (VA. u, 1 : length (VA))]
60+ (VectorOfArray (Δ′) ,nothing )
5061 end
5162 VA[i],AbstractVectorOfArray_getindex_adjoint
5263end
5364
5465ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} , j:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} ...)
5566 function AbstractVectorOfArray_getindex_adjoint (Δ)
56- Δ′ = zero (VA)
57- Δ′[i, j... ] = Δ
58- (Δ′, i, map (_ -> nothing , j)... )
67+ Δ′ = [(i == j ? zero (x) : Fill ( zero ( eltype (x)), size (x))) for (x,j) in zip (VA . u, 1 : length (VA))]
68+ Δ′[i][ j... ] = Δ
69+ (VectorOfArray (Δ′), nothing , map (_ -> nothing , j)... )
5970 end
6071 VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
6172end
62-
6373ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
6474 function ArrayPartition_adjoint (_y)
6575 y = Array (_y)
@@ -71,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal
7181end
7282
7383ZygoteRules. @adjoint function VectorOfArray (u)
74- VectorOfArray (u),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],)
84+ VectorOfArray (u),y -> (VectorOfArray ( [y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]]) ,)
7585end
7686
7787ZygoteRules. @adjoint function DiffEqArray (u,t)
78- DiffEqArray (u,t),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],nothing )
88+ DiffEqArray (u,t),y -> (DiffEqArray ( [y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],t) ,nothing )
7989end
8090
8191ZygoteRules. @adjoint function ZygoteRules. literal_getproperty (A:: ArrayPartition , :: Val{:x} )
0 commit comments