Skip to content

Commit 19862d1

Browse files
committed
comments
1 parent 2cca8fc commit 19862d1

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/lib/broadcast.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,10 @@ end
253253
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
254254
broadcast_forward(CUDA.cufunc(f), args...)
255255

256-
else # CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
256+
else # CUDA >= 3.0 -- don't need cufunc(f).
257+
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
258+
# so perhaps this can be deleted? Possible edge case here:
259+
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
257260

258261
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
259262
broadcast_forward(f, args...)

test/cuda.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ end
2626
@test g_gpu |> collect g
2727

2828
# https://github.com/FluxML/Zygote.jl/issues/1027
29-
@test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,)
30-
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
31-
@test cu(g3) gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1]
29+
@test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,) # was MethodError: no method matching iterate(::Nothing)
30+
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
31+
@test cu(g3) gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- Zygote v0.6.14, CUDA v3.3.0
3232
end
3333

3434
@testset "sum(f, x)" begin

0 commit comments

Comments
 (0)