Skip to content

Commit bf01ddf

Browse files
authored
fix iteration of thunks (#370)
1 parent 7947bed commit bf01ddf

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.10.4"
3+
version = "0.10.5"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/differentials/thunks.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(unthunk(x))
88
return element, (val, state)
99
end
1010

11-
@inline function Base.iterate(::AbstractThunk, (val, state))
12-
element, new_state = iterate(val, state)
13-
return element, (val, new_state)
11+
@inline function Base.iterate(::AbstractThunk, (underlying_object, state))
12+
next = iterate(underlying_object, state)
13+
next === nothing && return nothing
14+
element, new_state = next
15+
return element, (underlying_object, new_state)
1416
end
1517

1618
Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)

test/differentials/thunks.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
@test 3.2 == InplaceableThunk(@thunk(3.2), x -> x + 3.2)
88
end
99

10+
@testset "iterate" begin
11+
a = [1.0, 2.0, 3.0]
12+
t = @thunk(a)
13+
for (i, j) in zip(a, t)
14+
@test i == j
15+
end
16+
end
17+
1018
@testset "show" begin
1119
rep = repr(Thunk(rand))
1220
@test occursin(r"Thunk\(.*rand.*\)", rep)

0 commit comments

Comments
 (0)