
Commit 0cba74d

Merge pull request #1018 from mcabbott/broadcast_tidy
Tidy up a few things in broadcasting
2 parents ddf860c + 687adbc commit 0cba74d

3 files changed: +44 -35 lines

src/lib/array.jl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ using Distributed: pmap, AbstractWorkerPool
 @adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
 @adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)

-@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!
+@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count

 @adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)
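Note: adding count to the @nograd list is what lets expressions such as sum(x .^ 3) / count(x .> 3) differentiate cleanly (see the new tests in test/cuda.jl below). A CPU-side sketch of the expected behaviour, not part of this diff:

    using Zygote

    a = Float32.(1:9)
    # count is treated as non-differentiable, so the denominator acts as a
    # constant (here count(a .> 3) == 6) and only the numerator contributes:
    g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
    # expected: g3 ≈ 3 .* a .^ 2 ./ 6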

src/lib/broadcast.jl

Lines changed: 35 additions & 33 deletions
@@ -46,7 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
 end

 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x))
+trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)

 unbroadcast(x::AbstractArray, x̄) =
   size(x) == size(x̄) ? x̄ :
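Note: trim and unbroadcast reduce an incoming gradient back to the shape of each broadcast argument; a scalar argument, for instance, receives the sum over the broadcast dimensions. A small illustration (not part of this diff; the expected values are mine):

    using Zygote

    # The vector keeps a per-element gradient; the scalar's gradient is the
    # sum of the incoming cotangent over the broadcast dimension.
    gradient((x, y) -> sum(x .+ y), [1.0, 2.0, 3.0], 2.0)
    # expected: ([1.0, 1.0, 1.0], 3.0)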
@@ -75,23 +75,17 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
 
 @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
   Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
-@adjoint function broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number})
-  z, back = pullback(*, x, y) # this uses dot(y,Δ) instead of Δ .* conj.(y)
-  z, Δ -> (nothing, back(Δ)...)
-end
-@adjoint function broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number)
-  z, back = pullback(*, x, y)
-  z, Δ -> (nothing, back(Δ)...)
-end
+@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
+  _pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
+@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
+  _pullback(*, x, y)

 @adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
   res = x ./ y
   res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
 end
-@adjoint function broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number)
-  z, back = pullback(/, x, y)
-  z, Δ -> (nothing, back(Δ)...)
-end
+@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
+  _pullback(/, x, y)

 @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
   y = Base.literal_pow.(^, x, exp)
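Note: swapping the pullback(...) wrappers for _pullback(*, x, y) keeps the same three-slot return shape (a gradient slot for the function, then one per argument), while the underlying scalar-times-array rule accumulates the scalar's gradient via dot(y, Δ) rather than materialising Δ .* conj.(y) and summing. A quick consistency check (my sketch, not from the PR's tests):

    using Zygote

    # d/da sum(a .* v) = sum(v) = 6;  d/dv sum(a .* v) = fill(a, 3)
    gradient((a, v) -> sum(a .* v), 2.0, [1.0, 2.0, 3.0])
    # expected: (6.0, [2.0, 2.0, 2.0])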
@@ -106,10 +100,10 @@ end
 end

 @adjoint broadcasted(::typeof(conj), x::Numeric) =
-  conj.(x), z̄ -> (nothing, conj.(z̄))
+  conj(x), z̄ -> (nothing, conj(z̄))

 @adjoint broadcasted(::typeof(real), x::Numeric) =
-  real.(x), z̄ -> (nothing, real.(z̄))
+  real(x), z̄ -> (nothing, real(z̄))

 @adjoint broadcasted(::typeof(imag), x::Numeric) =
   imag.(x), z̄ -> (nothing, im .* real.(z̄))
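Note: dropping the dots is safe for the forward value because conj and real already act elementwise on numeric arrays, and Numeric covers both scalars and arrays. For instance:

    A = [1.0 + 2.0im, 3.0 - 4.0im]
    conj(A) == conj.(A)   # true
    real(A) == real.(A)   # true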
@@ -180,10 +174,9 @@ _dual_safearg(x) = false
   T = Broadcast.combine_eltypes(f, args)
   # Avoid generic broadcasting in two easy cases:
   if T == Bool
-    return f.(args...), _->nothing
+    return (f.(args...), _ -> nothing)
   elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
-    y, back = broadcast_forward(f, args...)
-    return y, ȳ -> (nothing, nothing, back(ȳ)...)
+    return broadcast_forward(f, args...)
   end
   len = inclen(args)
   y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
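Note: the T == Bool branch means boolean-valued broadcasts simply return nothing as their gradient, which is what the new GPU tests below rely on. The same should hold on the CPU (my sketch):

    using Zygote

    # Bool output is non-differentiable, so the argument gets no gradient:
    gradient(x -> sum(x .> 0), [1.0, -2.0, 3.0])
    # expected: (nothing,)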
@@ -195,7 +188,7 @@ _dual_safearg(x) = false
     end
     (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
   end
-  y, ∇broadcasted
+  return y, ∇broadcasted
 end

 @adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
@@ -226,6 +219,7 @@ using ForwardDiff: Dual
 
 dual(x, p) = x
 dual(x::Real, p) = Dual(x, p)
+dual(x::Bool, p) = x
 
 function dual_function(f::F) where F
   function (args::Vararg{Any,N}) where N
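Note: the new dual(x::Bool, p) method matters because Bool <: Real in Julia, so without it a boolean flowing through the forward-mode path would be wrapped in a Dual. A standalone sketch mirroring these three methods (illustrative only, not Zygote's internals):

    using ForwardDiff: Dual

    dual(x, p) = x
    dual(x::Real, p) = Dual(x, p)
    dual(x::Bool, p) = x           # Bools stay inert; they carry no partials

    dual(3.0, (1.0, 0.0))          # a Dual with value 3.0 and partials (1.0, 0.0)
    dual(true, (1.0, 0.0))         # true, returned unchanged
    dual("hi", (1.0, 0.0))         # non-numeric arguments also pass through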
@@ -237,28 +231,36 @@ function dual_function(f::F) where F
 end

 @inline function broadcast_forward(f, args::Vararg{Any,N}) where N
-  T = Broadcast.combine_eltypes(f, args)
+  valN = Val(N)
   out = dual_function(f).(args...)
   eltype(out) <: Dual || return (out, _ -> nothing)
   y = map(x -> x.value, out)
-  _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out))
-  back(ȳ) = ntuple(i -> _back(ȳ, i), N)
-  return y, back
+  function bc_fwd_back(ȳ)
+    dargs = ntuple(valN) do i
+      unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
+    end
+    (nothing, nothing, dargs...) # nothings for broadcasted & f
+  end
+  return y, bc_fwd_back
 end

 @init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+
   const CuArrayStyle = CUDA.AbstractGPUArrayStyle

-  if isdefined(CUDA, :cufunc)
-    @eval @adjoint function broadcasted(::CuArrayStyle, f, args...)
-      y, back = broadcast_forward(CUDA.cufunc(f), args...)
-      y, ȳ -> (nothing, nothing, back(ȳ)...)
-    end
-  else # CUDA >= 3.0
-    @eval @adjoint function broadcasted(::CuArrayStyle, f, args...)
-      y, back = broadcast_forward(f, args...)
-      y, ȳ -> (nothing, nothing, back(ȳ)...)
-    end
+  if isdefined(CUDA, :cufunc) # CUDA < 3.0
+
+    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
+      broadcast_forward(CUDA.cufunc(f), args...)
+
+  else # CUDA >= 3.0 -- don't need cufunc(f).
+    # Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
+    # so perhaps this can be deleted? Possible edge case here:
+    # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
+
+    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
+      broadcast_forward(f, args...)
+
   end

 @adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
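Note: bc_fwd_back now returns the leading nothings itself (for broadcasted and f), so the CuArrayStyle adjoints above can hand back broadcast_forward's result directly. The mechanism is forward mode through the broadcast: seed each argument with a unit partial, broadcast the dual function, then read off values and partials. A hand-rolled single-argument sketch (names and values are illustrative, not Zygote's code):

    using ForwardDiff: Dual

    xs  = [1.0, 2.0, 3.0]
    out = map(x -> sin(Dual(x, (1.0,))), xs)              # forward pass with unit seeds
    y   = map(d -> d.value, out)                          # primal, sin.(xs)
    ȳ   = ones(3)                                         # incoming cotangent
    dx  = map((ȳ1, d) -> ȳ1 * d.partials[1], ȳ, out)      # gradient, ≈ cos.(xs)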

test/cuda.jl

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,7 @@ CUDA.allowscalar(false)
   @test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
 end

-@testset "basic bcasting" begin
+@testset "broadcasting" begin
   a = Float32.(1:9)
   a_gpu = a |> cu

@@ -24,6 +24,13 @@ end
   g_gpu = gradient(x -> w(x), a_gpu)[1]
   @test g_gpu isa CuArray
   @test g_gpu |> collect ≈ g
+
+  # https://github.com/FluxML/Zygote.jl/issues/1027 # status on Zygote v0.6.14, CUDA v3.3.0 in comments:
+  @test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,) # was MethodError: no method matching iterate(::Nothing)
+  @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
+  g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
+  @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
+  @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
 end

 @testset "sum(f, x)" begin
