Skip to content

Commit 5ffbd43

Browse files
authored
Replace @require CUDA with using GPUArraysCore (#1272)
* require GPUArrays instead of CUDA * more * change to unconditionally load GPUArraysCore * add GPUArrays dep * trivial trigger commit
1 parent 995778d commit 5ffbd43

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1010
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1212
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
13+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound
14+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1315
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1416
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1517
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -31,6 +33,8 @@ ChainRulesTestUtils = "1"
3133
DiffRules = "1.4"
3234
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
3335
ForwardDiff = "0.10"
36+
GPUArrays = "8.4.2" # not loaded, just a version bound
37+
GPUArraysCore = "0.1.1"
3438
IRTools = "0.4.4"
3539
LogExpFunctions = "0.3.1"
3640
MacroTools = "0.5"

src/lib/broadcast.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -253,43 +253,32 @@ end
253253
return y, bc_fwd_back
254254
end
255255

256-
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
256+
using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame
257257

258-
const CuArrayStyle = CUDA.AbstractGPUArrayStyle
259-
260-
if isdefined(CUDA, :cufunc) # CUDA < 3.0
261-
262-
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
263-
broadcast_forward(CUDA.cufunc(f), args...)
264-
265-
else # CUDA >= 3.0 -- don't need cufunc(f).
266258
# Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
267259
# so perhaps this can be deleted? Possible edge case here:
268260
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
261+
@adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
262+
broadcast_forward(f, args...)
269263

270-
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
271-
broadcast_forward(f, args...)
272-
273-
end
274-
275-
@adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
264+
@adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
276265
T(xs), Δ -> (convert(Array, Δ), )
277266

278-
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
267+
@adjoint function sum(xs::AbstractGPUArray; dims = :)
279268
placeholder = similar(xs)
280269
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
281270
end
282271

283272
# Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
284273
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
285-
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
274+
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
286275
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
287276
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
288277
end
289278

290-
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
279+
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
291280
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
292281
end
293282

294-
@eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
295-
end
283+
pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
284+

0 commit comments

Comments (0)