Skip to content

Commit 1d94f35

Browse files
Merge pull request #169 from SciML/grad
Gradient definitions & supertypes for Zygote, continued
2 parents 196ecd7 + b3ed973 commit 1d94f35

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "2.17.2"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1213
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using Requires, RecipesBase, StaticArrays, Statistics,
1111
import ChainRulesCore
1212
import ChainRulesCore: NoTangent
1313
import ZygoteRules
14+
15+
using FillArrays
16+
1417
# Abstract supertype that the `getindex` differentiation rules dispatch on.
# Every subtype is also an `AbstractArray{T, N}`.
# NOTE(review): `A` presumably names the backing storage type — confirm.
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1518
# Abstract subtype of `AbstractVectorOfArray` that forwards the same three type
# parameters, so methods written for `AbstractVectorOfArray` also accept it.
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1619

src/zygote.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Reverse rule for `VA[i]` with a single index on an `AbstractVectorOfArray`.
#
# Returns the primal `VA[i]` and a pullback. The pullback builds one entry per
# element of `VA.u`: the incoming cotangent `Δ` where the entry's position
# compares `==` to `i`, and `zero(x)` otherwise. The result is wrapped in a
# `VectorOfArray`; `getindex` itself and the index get `NoTangent()`.
function ChainRulesCore.rrule(
    ::typeof(getindex),
    VA::AbstractVectorOfArray,
    i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}},
)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        indexed = zip(VA.u, 1:length(VA))
        slots = [(i == pos ? Δ : zero(entry)) for (entry, pos) in indexed]
        return (NoTangent(), VectorOfArray(slots), NoTangent())
    end
    return VA[i], AbstractVectorOfArray_getindex_adjoint
end
@@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi
1010
# Pullback for multi-index `getindex`: write the cotangent `Δ` into a zeroed
# copy of `VA` at `indices`, wrap it in a `VectorOfArray`, and return
# `NoTangent()` for `getindex` itself and for every index argument.
function AbstractVectorOfArray_getindex_adjoint(Δ)
    grad = zero(VA)
    grad[indices...] = Δ
    index_tangents = map(_ -> NoTangent(), indices)
    return (NoTangent(), VectorOfArray(grad), index_tangents...)
end
1515
VA[indices...],AbstractVectorOfArray_getindex_adjoint
1616
end
@@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
1919
# Pullback for the `ArrayPartition(x, Val{copy_x})` constructor.
#
# `_y` is the cotangent of the constructed partition. It is flattened with
# `Array`, cut into consecutive chunks whose lengths match the entries of the
# captured tuple `x`, and each chunk is reshaped to the size of its entry.
# The chunks are returned wrapped in an `ArrayPartition`; the constructor type
# and the `Val` argument get `NoTangent()`.
function ArrayPartition_adjoint(_y)
    y = Array(_y)
    # starts[i] is the number of elements that precede partition i in `y`.
    starts = vcat(0, cumsum(reduce(vcat, length.(x))))
    # `ntuple` takes the generator AND the count; the count must not leak out
    # as a second argument to `ArrayPartition`.
    parts = ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))
    return NoTangent(), ArrayPartition(parts), NoTangent()
end
2424

2525
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
@@ -43,23 +43,33 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
4545

46+
# Define a new species of projection operator for this type:
47+
# Projector constructor for `VectorOfArray`: the argument is used only for
# dispatch, and the projector is built without any keyword data.
function ChainRulesCore.ProjectTo(x::VectorOfArray)
    return ChainRulesCore.ProjectTo{VectorOfArray}()
end
48+
49+
# The gradient from iteration will be e.g. a Vector{Vector}; this makes it another AbstractMatrix.
50+
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
51+
# Gradient from broadcasting will be another AbstractArray
52+
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
53+
54+
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
55+
# definition first, and finds its own before finding those.
56+
4657
# Zygote adjoint for `VA[i]` with a single index.
#
# The pullback returns a `VectorOfArray` with one entry per element of `VA.u`:
# the cotangent `Δ` where the entry's position compares `==` to `i`, and a
# `Fill` of zeros (same element type and size as the entry) otherwise.
# The index itself receives `nothing`.
ZygoteRules.@adjoint function getindex(
    VA::AbstractVectorOfArray,
    i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}},
)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        indexed = zip(VA.u, 1:length(VA))
        slots = [(i == pos ? Δ : Fill(zero(eltype(entry)), size(entry))) for (entry, pos) in indexed]
        (VectorOfArray(slots), nothing)
    end
    VA[i], AbstractVectorOfArray_getindex_adjoint
end
5364

5465
# Zygote adjoint for `VA[i, j...]` (an outer index followed by inner indices).
#
# The pullback builds one entry per element of `VA.u`: a mutable `zero(entry)`
# where the entry's position compares `==` to `i`, and a `Fill` of zeros
# elsewhere. It then writes `Δ` into entry `i` at `j...` and wraps the result
# in a `VectorOfArray`. Every index argument receives `nothing`.
ZygoteRules.@adjoint function getindex(
    VA::AbstractVectorOfArray,
    i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}},
    j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...,
)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        indexed = zip(VA.u, 1:length(VA))
        # The loop variable is called `pos` so it cannot be confused with the
        # trailing indices `j` used on the next line.
        slots = [(i == pos ? zero(entry) : Fill(zero(eltype(entry)), size(entry))) for (entry, pos) in indexed]
        slots[i][j...] = Δ
        (VectorOfArray(slots), nothing, map(_ -> nothing, j)...)
    end
    VA[i, j...], AbstractVectorOfArray_getindex_adjoint
end
62-
6373
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
6474
function ArrayPartition_adjoint(_y)
6575
y = Array(_y)
@@ -71,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal
7181
end
7282

7383
# Zygote adjoint for the `VectorOfArray(u)` constructor.
#
# The pullback slices the cotangent `y` along its last dimension, taking
# everything (`:`) in each leading dimension, and returns those slices wrapped
# in a `VectorOfArray` as the gradient for `u`.
ZygoteRules.@adjoint function VectorOfArray(u)
    function VectorOfArray_adjoint(y)
        steps = 1:size(y)[end]
        leading = ntuple(_ -> Colon(), ndims(y) - 1)
        slices = [y[leading..., k] for k in steps]
        (VectorOfArray(slices),)
    end
    VectorOfArray(u), VectorOfArray_adjoint
end
7686

7787
# Zygote adjoint for the `DiffEqArray(u, t)` constructor.
#
# The pullback slices the cotangent `y` along its last dimension, taking
# everything (`:`) in each leading dimension, and returns those slices as a
# `DiffEqArray` built with the original `t`. The gradient for `t` is `nothing`.
ZygoteRules.@adjoint function DiffEqArray(u, t)
    function DiffEqArray_adjoint(y)
        steps = 1:size(y)[end]
        leading = ntuple(_ -> Colon(), ndims(y) - 1)
        slices = [y[leading..., k] for k in steps]
        (DiffEqArray(slices, t), nothing)
    end
    DiffEqArray(u, t), DiffEqArray_adjoint
end
8090

8191
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})

0 commit comments

Comments
 (0)