Skip to content

Commit 4028afc

Browse files
Make thunked unthunk before broadcasting (#97)
* Make thunked unthunk (not extern) before broadcasting * bump version * Fix spacing Co-Authored-By: Nick Robinson <npr251@gmail.com> Co-authored-by: Nick Robinson <npr251@gmail.com>
1 parent c782e40 commit 4028afc

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.5.3"
3+
version = "0.5.4"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

src/differentials/thunks.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11

22
abstract type AbstractThunk <: AbstractDifferential end
33

4-
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
4+
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(unthunk(x))
55

66
@inline function Base.iterate(x::AbstractThunk)
7-
externed = extern(x)
8-
element, state = iterate(externed)
9-
return element, (externed, state)
7+
val = unthunk(x)
8+
element, state = iterate(val)
9+
return element, (val, state)
1010
end
1111

12-
@inline function Base.iterate(::AbstractThunk, (externed, state))
13-
element, new_state = iterate(externed, state)
14-
return element, (externed, new_state)
12+
@inline function Base.iterate(::AbstractThunk, (val, state))
13+
element, new_state = iterate(val, state)
14+
return element, (val, new_state)
1515
end
1616

1717
#####
@@ -62,7 +62,7 @@ Do not use `@thunk` if this would be equal or more work than actually evaluating
6262
- The expression is merely wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
6363
- The expression being itself a `thunk`
6464
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
65-
- There is only one derivative being returned, so from the fact that the user called `frule`/`rrule`
65+
- There is only one derivative being returned, so from the fact that the user called `frule`/`rrule`
6666
they clearly will want to use that one.
6767
"""
6868
struct Thunk{F} <: AbstractThunk

test/differentials/thunks.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,40 @@
3535
@test stackframe.file == Symbol(@__FILE__)
3636
end
3737
end
38+
39+
40+
@testset "Broadcast" begin
41+
@testset "Array" begin
42+
was_unthunked = 0
43+
array_thunk = @thunk begin
44+
was_unthunked += 1;
45+
[1.0, 2.0, 3.0]
46+
end
47+
48+
was_unthunked = 0
49+
@test array_thunk .+ fill(10, 3) .+ fill(10, 3) == [21.0, 22.0, 23.0]
50+
@test was_unthunked == 1
51+
52+
was_unthunked = 0
53+
@test array_thunk .+ 10.0 .+ 10.0 == [21.0, 22.0, 23.0]
54+
@test was_unthunked == 1
55+
56+
end
57+
58+
@testset "Scalar" begin
59+
was_unthunked=0
60+
scalar_thunk = @thunk begin
61+
was_unthunked += 1;
62+
sqrt(4.0)
63+
end
64+
65+
was_unthunked = 0
66+
@test scalar_thunk .+ fill(10, 3) .+ fill(10, 3) == [22.0, 22.0, 22.0]
67+
@test was_unthunked == 1
68+
69+
was_unthunked = 0
70+
@test scalar_thunk .+ 10.0 .+ 10.0 == 22.0
71+
@test was_unthunked == 1
72+
end
73+
end
3874
end

0 commit comments

Comments
 (0)