Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
(NoTangent(),Δ′,NoTangent())
(NoTangent(),VectorOfArray(Δ′),NoTangent())
end
VA[i],AbstractVectorOfArray_getindex_adjoint
end
Expand All @@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = zero(VA)
Δ′[indices...] = Δ
(NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...)
(NoTangent(), VectorOfArray(Δ′), indices[1],map(_ -> NoTangent(), indices[2:end])...)
end
VA[indices...],AbstractVectorOfArray_getindex_adjoint
end
Expand All @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
function ArrayPartition_adjoint(_y)
y = Array(_y)
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent()
end

ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
Expand All @@ -43,10 +43,21 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
A.x,literal_ArrayPartition_x_adjoint
end

# Define a new species of projection operator for this type:
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()

# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
# Gradient from broadcasting will be another AbstractArray
(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this may not be necessary at all. One thing I thought to test was whether iteration like this worked without it, but it does, it hits the @adjoint getindex rule:

julia> function iter(vofa)
       s = 0
       for a in vofa
         s += prod(a)
       end
       s
       end;

julia> gradient(iter, va)[1]
VectorOfArray{Float64,3}:
2-element Vector{Matrix{Float64}}:
 [0.007377548303139424 0.0014293014720444096 0.0004998127128840348; 0.0005414269337139141 0.0007721441834498009 0.0006559506948612249; 0.0042378737935180105 0.0006765914005991947 0.00045986425172967415]
 [0.002774992305796606 0.002978041675310144 0.004412709924140469; 0.0056425205408066285 0.005228088118453952 0.0036646150274027; 0.0036825199036535465 0.004902176341789764 0.045170987413739046]


# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
# definition first, and finds its own before finding those.

ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to the PR, but this seems like it allocates quite a bit, when iterating a VectorOfArray. I guess that using Fill(0.0, size(Δ)) would often make Δ′ have an abstract type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would make it an abstract type and sometimes hurt inference. Then we'd have to rely on union optimizations and pray.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relying on union optimizations might be the right idea here though, I'll have to check.

(Δ′,nothing)
(VectorOfArray(Δ′),nothing)
end
VA[i],AbstractVectorOfArray_getindex_adjoint
end
Expand All @@ -55,11 +66,10 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = zero(VA)
Δ′[i,j...] = Δ
(Δ′, i,map(_ -> nothing, j)...)
(VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
end
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
end

ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
function ArrayPartition_adjoint(_y)
y = Array(_y)
Expand All @@ -71,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal
end

ZygoteRules.@adjoint function VectorOfArray(u)
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
VectorOfArray(u),y -> (VectorOfArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]]),)
end

ZygoteRules.@adjoint function DiffEqArray(u,t)
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
DiffEqArray(u,t),y -> (DiffEqArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],t),nothing)
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
Expand Down