Skip to content

Commit 4b7524e

Browse files
committed
use ProjectTo in broadcasting, etc
1 parent 0cba74d commit 4b7524e

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

src/compiler/chainrules.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
128128
ChainRules.Tangent{Any, typeof(xp)}(xp)
129129
end
130130

131+
"""
132+
_project(x)(dx)
133+
_project(x, dx)
134+
135+
The function `_project(x)` returns a projector, which standardises the gradient `dx` for type & shape.
136+
Uses `ChainRulesCore.ProjectTo`, but is safe to apply to arbitrary input.
137+
The two-argument `_project(x, dx)` applies this immediately.
138+
"""
139+
@inline _project(x) = identity # fallback: do nothing!
140+
@inline _project(x::Numeric) = wrap_chainrules_output ProjectTo(x)
141+
@inline _project(x::Ref{<:Numeric}) = wrap_chainrules_output ProjectTo(x)
142+
143+
@inline _project(x, dx) = _project(x)(dx)
144+
131145
"""
132146
ZBack{F}(back) <: Function
133147

src/compiler/interface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d
7373
"""
7474
function gradient(f, args...)
7575
y, back = pullback(f, args...)
76-
return back(sensitivity(y))
76+
grad = back(sensitivity(y))
77+
map(_project, args, grad)
7778
end
7879

7980
Base.adjoint(f::Function) = x -> gradient(f, x)[1]
@@ -95,7 +96,8 @@ true
9596
"""
9697
function withgradient(f, args...)
9798
y, back = pullback(f, args...)
98-
(val = y, grad = back(sensitivity(y)))
99+
grad = back(sensitivity(y))
100+
(val = y, grad = map(_project, args, grad))
99101
end
100102

101103
# Param-style wrappers

src/lib/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
dxv = view(dx, inds...)
4444
dxv .= accum.(dxv, _droplike(dy, dxv))
4545
end
46-
return (dx, map(_->nothing, inds)...)
46+
return (_project(x, dx), map(_->nothing, inds)...)
4747
end
4848

4949
"""

src/lib/broadcast.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
4545
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
4646
end
4747

48-
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
49-
trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
50-
5148
unbroadcast(x::AbstractArray, x̄) =
52-
size(x) == size(x̄) ?:
53-
length(x) == length(x̄) ? trim(x, x̄) :
54-
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
49+
length(x) == length(x̄) ? _project(x, x̄) :
50+
_project(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
5551

5652
unbroadcast(x::Number, x̄) = accum_sum(x̄)
5753
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
5854
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
59-
unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ?: accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
55+
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ?: accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
6056

6157
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
6258

0 commit comments

Comments
 (0)