Skip to content

Commit 353361e

Browse files
committed
Add debug mode for bad inplace
1 parent 3e7cf88 commit 353361e

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

docs/src/debug_mode.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ To enable, redefine the [`ChainRulesCore.debug_mode`](@ref) function to return `
1111
```julia
1212
ChainRulesCore.debug_mode() = true
1313
```
14+
15+
## Features of Debug Mode:
16+
17+
- If you add a `Composite` to a primal value, and it was unable to construct a new primal values, then a better error message will be displayed detailing what overloads need to be written to fix this.
18+
- during [`add!!`](@ref), if an `InplaceThunk` is used, and it runs the code that is supposed to run in place, but the return result is not the input (with updated values), then an error is thrown. Rather than silently using what ever values were returned.

src/accumulation.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only ca
1414
"""
1515
function add!!(x, t::InplaceableThunk)
1616
return if is_inplaceable_destination(x)
17-
t.add!(x)
17+
if !debug_mode()
18+
t.add!(x)
19+
else
20+
return debug_add!(x, t)
21+
end
1822
else
1923
x + t
2024
end
@@ -51,3 +55,28 @@ function is_inplaceable_destination(x::AbstractArray)
5155
# processing and so are mutable if their `parent` is.
5256
return is_inplaceable_destination(p)
5357
end
58+
59+
function debug_add!(accumuland, t::InplaceableThunk)
60+
returned_value = t.add!(accumuland)
61+
if returned_value !== accumuland
62+
throw(BadInplaceException(t, accumuland, returned_value))
63+
end
64+
return returned_value
65+
end
66+
struct BadInplaceException <: Exception
67+
t::InplaceableThunk
68+
accumuland
69+
returned_value
70+
end
71+
function Base.showerror(io::IO, err::BadInplaceException)
72+
println(io, "add! of $(err.t) did not return an updated accumuland.")
73+
println(io, "accumuland = $(err.accumuland)")
74+
println(io, "returned_value = $(err.returned_value)")
75+
76+
if err.accumuland == err.returned_value
77+
println(
78+
io,
79+
"Which in this case happenned to be equal. But they are not the same object."
80+
)
81+
end
82+
end

test/accumulation.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,30 @@
106106
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
107107
end
108108
end
109+
110+
@testset "not actually inplace but said it was" begin
111+
ithunk = InplaceableThunk(
112+
@thunk(@assert false), # this should never be used in this test
113+
x -> 77*ones(2, 2) # not actually inplace (also wrong)
114+
)
115+
accumuland = ones(2, 2)
116+
@assert ChainRulesCore.debug_mode() == false
117+
# without debug being enabled should return the result, not error
118+
@test 77*ones(2, 2) == add!!(accumuland, ithunk)
119+
120+
ChainRulesCore.debug_mode() = true # enable debug mode
121+
# with debug being enabled should error
122+
@test_throws ChainRulesCore.BadInplaceException add!!(accumuland, ithunk)
123+
ChainRulesCore.debug_mode() = false # disable it again
124+
end
125+
end
126+
@testset "showerror BadInplaceException" begin
127+
BadInplaceException = ChainRulesCore.BadInplaceException
128+
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing)
129+
msg = sprint(showerror, BadInplaceException(ithunk, [22], [23]))
130+
@test occursin("22", msg)
131+
132+
msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22]))
133+
@test occursin("equal", msg_equal)
109134
end
110135
end

0 commit comments

Comments
 (0)