@@ -253,43 +253,32 @@ end
253253 return y, bc_fwd_back
254254end
255255
256- @init @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba " begin
256+ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame
257257
258- const CuArrayStyle = CUDA. AbstractGPUArrayStyle
259-
260- if isdefined (CUDA, :cufunc ) # CUDA < 3.0
261-
262- @eval @adjoint broadcasted (:: CuArrayStyle , f, args... ) =
263- broadcast_forward (CUDA. cufunc (f), args... )
264-
265- else # CUDA >= 3.0 -- don't need cufunc(f).
266258 # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
267259 # so perhaps this can be deleted? Possible edge case here:
268260 # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
261+ @adjoint broadcasted (:: AbstractGPUArrayStyle , f, args... ) =
262+ broadcast_forward (f, args... )
269263
270- @eval @adjoint broadcasted (:: CuArrayStyle , f, args... ) =
271- broadcast_forward (f, args... )
272-
273- end
274-
275- @adjoint (:: Type{T} )(xs:: Array ) where {T <: CUDA.CuArray } =
264+ @adjoint (:: Type{T} )(xs:: Array ) where {T <: AbstractGPUArray } =
276265 T (xs), Δ -> (convert (Array, Δ), )
277266
278- @adjoint function sum (xs:: CUDA. AbstractGPUArray ; dims = :)
267+ @adjoint function sum (xs:: AbstractGPUArray ; dims = :)
279268 placeholder = similar (xs)
280269 sum (xs, dims = dims), Δ -> (placeholder .= Δ,)
281270 end
282271
283272 # Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
284273 # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
285- @adjoint function sum (f, xs:: CUDA. AbstractGPUArray ; kws... )
274+ @adjoint function sum (f, xs:: AbstractGPUArray ; kws... )
286275 @assert ! haskey (kws, :init ) # TODO add init support (julia 1.6)
287276 return pullback (__context__, (f, xs) -> sum (f .(xs); kws... ), f, xs)
288277 end
289278
290- @adjoint function Base. convert (:: Type{T} , xs:: Array ) where {T<: CUDA. AbstractGPUArray }
279+ @adjoint function Base. convert (:: Type{T} , xs:: Array ) where {T<: AbstractGPUArray }
291280 Base. convert (T, xs), Δ -> (nothing , Base. convert (Array, Δ),)
292281 end
293282
294- @eval pull_block_vert (sz, Δ:: CUDA.CuArray , A:: Number ) = CUDA . @allowscalar Δ[sz]
295- end
283+ pull_block_vert (sz, Δ:: AbstractGPUArray , A:: Number ) = @allowscalar Δ[sz]
284+
0 commit comments