Skip to content

Commit b3ed973

Browse files
fix pullbacks and use some FillArrays
1 parent c84842a commit b3ed973

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "2.17.2"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1213
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using Requires, RecipesBase, StaticArrays, Statistics,
1111
import ChainRulesCore
1212
import ChainRulesCore: NoTangent
1313
import ZygoteRules
14+
15+
using FillArrays
16+
1417
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1518
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1619

src/zygote.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
4848

4949
# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
50-
(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
50+
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
5151
# Gradient from broadcasting will be another AbstractArray
5252
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
5353

@@ -56,16 +56,16 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr
5656

5757
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
5858
function AbstractVectorOfArray_getindex_adjoint(Δ)
59-
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
59+
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
6060
(VectorOfArray(Δ′),nothing)
6161
end
6262
VA[i],AbstractVectorOfArray_getindex_adjoint
6363
end
6464

6565
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
6666
function AbstractVectorOfArray_getindex_adjoint(Δ)
67-
Δ′ = zero(VA)
68-
Δ′[i,j...] = Δ
67+
Δ′ = [(i == j ? zero(x) : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
68+
Δ′[i][j...] = Δ
6969
(VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
7070
end
7171
VA[i,j...],AbstractVectorOfArray_getindex_adjoint

0 commit comments

Comments
 (0)