@@ -46,7 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
4646end
4747
4848trim (x, Δ) = reshape (Δ, ntuple (i -> size (Δ, i), Val (ndims (x))))
49- trim (x:: Tuple , Δ) = ntuple (k -> Δ[k], length (x))
49+ trim (x:: Tuple , Δ) = NTuple { length(x)} (Δ )
5050
5151unbroadcast (x:: AbstractArray , x̄) =
5252 size (x) == size (x̄) ? x̄ :
@@ -75,23 +75,17 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
7575
7676@adjoint broadcasted (:: typeof (* ), x:: Numeric , y:: Numeric ) = x.* y,
7777 Δ -> (nothing , unbroadcast (x, Δ .* conj .(y)), unbroadcast (y, Δ .* conj .(x)))
78- @adjoint function broadcasted (:: typeof (* ), x:: Number , y:: AbstractArray{<:Number} )
79- z, back = pullback (* , x, y) # this uses dot(y,Δ) instead of Δ .* conj.(y)
80- z, Δ -> (nothing , back (Δ)... )
81- end
82- @adjoint function broadcasted (:: typeof (* ), x:: AbstractArray{<:Number} , y:: Number )
83- z, back = pullback (* , x, y)
84- z, Δ -> (nothing , back (Δ)... )
85- end
78+ @adjoint broadcasted (:: typeof (* ), x:: Number , y:: AbstractArray{<:Number} ) =
79+ _pullback (* , x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
80+ @adjoint broadcasted (:: typeof (* ), x:: AbstractArray{<:Number} , y:: Number ) =
81+ _pullback (* , x, y)
8682
8783@adjoint function broadcasted (:: typeof (/ ), x:: Numeric , y:: Numeric )
8884 res = x ./ y
8985 res, Δ -> (nothing , unbroadcast (x, Δ ./ conj .(y)), unbroadcast (y, .- Δ .* conj .(res ./ y)))
9086end
91- @adjoint function broadcasted (:: typeof (/ ), x:: AbstractArray{<:Number} , y:: Number )
92- z, back = pullback (/ , x, y)
93- z, Δ -> (nothing , back (Δ)... )
94- end
87+ @adjoint broadcasted (:: typeof (/ ), x:: AbstractArray{<:Number} , y:: Number ) =
88+ _pullback (/ , x, y)
9589
9690@adjoint function broadcasted (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Numeric , exp:: Val{p} ) where p
9791 y = Base. literal_pow .(^ , x, exp)
@@ -106,10 +100,10 @@ end
106100end
107101
108102@adjoint broadcasted (:: typeof (conj), x:: Numeric ) =
109- conj . (x), z̄ -> (nothing , conj . (z̄))
103+ conj (x), z̄ -> (nothing , conj (z̄))
110104
111105@adjoint broadcasted (:: typeof (real), x:: Numeric ) =
112- real . (x), z̄ -> (nothing , real . (z̄))
106+ real (x), z̄ -> (nothing , real (z̄))
113107
114108@adjoint broadcasted (:: typeof (imag), x:: Numeric ) =
115109 imag .(x), z̄ -> (nothing , im .* real .(z̄))
@@ -180,10 +174,9 @@ _dual_safearg(x) = false
180174 T = Broadcast. combine_eltypes (f, args)
181175 # Avoid generic broadcasting in two easy cases:
182176 if T == Bool
183- return f .(args... ), _-> nothing
177+ return ( f .(args... ), _ -> nothing )
184178 elseif T <: Real && isconcretetype (T) && _dual_purefun (F) && all (_dual_safearg, args)
185- y, back = broadcast_forward (f, args... )
186- return y, ȳ -> (nothing , nothing , back (ȳ)... )
179+ return broadcast_forward (f, args... )
187180 end
188181 len = inclen (args)
189182 y∂b = _broadcast ((x... ) -> _pullback (__context__, f, x... ), args... )
@@ -195,7 +188,7 @@ _dual_safearg(x) = false
195188 end
196189 (nothing , accum_sum (dxs[1 ]), map (unbroadcast, args, Base. tail (dxs))... )
197190 end
198- y, ∇broadcasted
191+ return y, ∇broadcasted
199192end
200193
201194@adjoint function broadcasted (:: AbstractArrayStyle{0} , f, args... )
@@ -226,6 +219,7 @@ using ForwardDiff: Dual
226219
227220dual (x, p) = x
228221dual (x:: Real , p) = Dual (x, p)
222+ dual (x:: Bool , p) = x
229223
230224function dual_function (f:: F ) where F
231225 function (args:: Vararg{Any,N} ) where N
@@ -237,28 +231,36 @@ function dual_function(f::F) where F
237231end
238232
239233@inline function broadcast_forward (f, args:: Vararg{Any,N} ) where N
240- T = Broadcast . combine_eltypes (f, args )
234+ valN = Val (N )
241235 out = dual_function (f).(args... )
242236 eltype (out) <: Dual || return (out, _ -> nothing )
243237 y = map (x -> x. value, out)
244- _back (ȳ, i) = unbroadcast (args[i], ((a, b) -> a* b. partials[i]). (ȳ, out))
245- back (ȳ) = ntuple (i -> _back (ȳ, i), N)
246- return y, back
238+ function bc_fwd_back (ȳ)
239+ dargs = ntuple (valN) do i
240+ unbroadcast (args[i], broadcast ((y1, o1) -> y1 * o1. partials[i], ȳ, out))
241+ end
242+ (nothing , nothing , dargs... ) # nothings for broadcasted & f
243+ end
244+ return y, bc_fwd_back
247245end
248246
249247@init @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba" begin
248+
250249 const CuArrayStyle = CUDA. AbstractGPUArrayStyle
251250
252- if isdefined (CUDA, :cufunc )
253- @eval @adjoint function broadcasted (:: CuArrayStyle , f, args... )
254- y, back = broadcast_forward (CUDA. cufunc (f), args... )
255- y, ȳ -> (nothing , nothing , back (ȳ)... )
256- end
257- else # CUDA >= 3.0
258- @eval @adjoint function broadcasted (:: CuArrayStyle , f, args... )
259- y, back = broadcast_forward (f, args... )
260- y, ȳ -> (nothing , nothing , back (ȳ)... )
261- end
251+ if isdefined (CUDA, :cufunc ) # CUDA < 3.0
252+
253+ @eval @adjoint broadcasted (:: CuArrayStyle , f, args... ) =
254+ broadcast_forward (CUDA. cufunc (f), args... )
255+
256+ else # CUDA >= 3.0 -- don't need cufunc(f).
257+ # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
258+ # so perhaps this can be deleted? Possible edge case here:
259+ # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
260+
261+ @eval @adjoint broadcasted (:: CuArrayStyle , f, args... ) =
262+ broadcast_forward (f, args... )
263+
262264 end
263265
264266 @adjoint CUDA. CuArray {N,T} (xs:: Array ) where {N,T} =
0 commit comments