Skip to content

Commit 791f25d

Browse files
committed
small Improvements based on code review
1 parent 353361e commit 791f25d

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/accumulation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function add!!(x, t::InplaceableThunk)
1717
if !debug_mode()
1818
t.add!(x)
1919
else
20-
return debug_add!(x, t)
20+
debug_add!(x, t)
2121
end
2222
else
2323
x + t
@@ -34,7 +34,7 @@ end
3434

3535

3636
"""
37-
is_inplaceable_destination(x)
37+
is_inplaceable_destination(x) -> Bool
3838
3939
Returns true if `x` is suitable for for storing inplace accumulation of gradients.
4040
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
@@ -63,13 +63,16 @@ function debug_add!(accumuland, t::InplaceableThunk)
6363
end
6464
return returned_value
6565
end
66+
6667
struct BadInplaceException <: Exception
67-
t::InplaceableThunk
68+
ithunk::InplaceableThunk
6869
accumuland
6970
returned_value
7071
end
72+
7173
function Base.showerror(io::IO, err::BadInplaceException)
72-
println(io, "add! of $(err.t) did not return an updated accumuland.")
74+
println(io, "`add!!(accumuland, ithunk))` did not return an updated accumuland.")
75+
println(io, "ithunk = $(err.ithunk)")
7376
println(io, "accumuland = $(err.accumuland)")
7477
println(io, "returned_value = $(err.returned_value)")
7578

test/accumulation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
ChainRulesCore.debug_mode() = false # disable it again
124124
end
125125
end
126+
126127
@testset "showerror BadInplaceException" begin
127128
BadInplaceException = ChainRulesCore.BadInplaceException
128129
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing)

0 commit comments

Comments
 (0)