@@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343 A. x,literal_ArrayPartition_x_adjoint
4444end
4545
46+ #=
47+
48+ # Define a new species of projection operator for this type:
49+ ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
50+
51+ # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
52+ (::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
53+ # Gradient from broadcasting will be another AbstractArray
54+ (::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
55+
56+ But this may not be necessary?
57+
58+ =#
59+
60+
61+ # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
62+ # definition first, and finds its own before finding those.
63+
4664ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
4765 function AbstractVectorOfArray_getindex_adjoint (Δ)
4866 Δ′ = [ (i == j ? Δ : zero (x)) for (x,j) in zip (VA. u, 1 : length (VA))]
49- (Δ′ ,nothing )
67+ (VectorOfArray (Δ′) ,nothing )
5068 end
5169 VA[i],AbstractVectorOfArray_getindex_adjoint
5270end
@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
5573 function AbstractVectorOfArray_getindex_adjoint (Δ)
5674 Δ′ = zero (VA)
5775 Δ′[i,j... ] = Δ
58- (Δ′, i,map (_ -> nothing , j)... )
76+ @show Δ′
77+ # (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
78+ (Δ′, nothing , map (_ -> nothing , j)... )
79+ # (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
5980 end
6081 VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
6182end
62-
6383ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
6484 function ArrayPartition_adjoint (_y)
6585 y = Array (_y)
0 commit comments