Skip to content

Commit b33988e

Browse files
Merge pull request #1062 from DhairyaLGandhi/dg/cutypevar
Allow buffer typevar in CuArray type
2 parents 57adb2d + 649a6ac commit b33988e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/lib/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ end
263263

264264
end
265265

266-
@adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
267-
CUDA.CuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
266+
@adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
267+
T(xs), Δ -> (convert(Array, Δ), )
268268

269269
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
270270
placeholder = similar(xs)

0 commit comments

Comments
 (0)