4747ChainRulesCore. ProjectTo (x:: VectorOfArray ) = ChainRulesCore. ProjectTo {VectorOfArray} ()
4848
4949# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
50- (:: ChainRulesCore.ProjectTo{VectorOfArray} )(dx:: AbstractVector{<:AbstractArray} ) = VectorOfArray (dx)
50+ # (::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
5151# Gradient from broadcasting will be another AbstractArray
5252# (::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
5353
@@ -56,16 +56,16 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr
5656
5757ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
5858 function AbstractVectorOfArray_getindex_adjoint (Δ)
59- Δ′ = [ (i == j ? Δ : zero (x )) for (x,j) in zip (VA. u, 1 : length (VA))]
59+ Δ′ = [(i == j ? Δ : Fill ( zero (eltype (x)), size (x) )) for (x,j) in zip (VA. u, 1 : length (VA))]
6060 (VectorOfArray (Δ′),nothing )
6161 end
6262 VA[i],AbstractVectorOfArray_getindex_adjoint
6363end
6464
6565ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} , j:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} ...)
6666 function AbstractVectorOfArray_getindex_adjoint (Δ)
67- Δ′ = zero (VA)
68- Δ′[i, j... ] = Δ
67+ Δ′ = [(i == j ? zero (x) : Fill ( zero ( eltype (x)), size (x))) for (x,j) in zip (VA . u, 1 : length (VA))]
68+ Δ′[i][ j... ] = Δ
6969 (VectorOfArray (Δ′), nothing , map (_ -> nothing , j)... )
7070 end
7171 VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
0 commit comments