Skip to content

Commit f5e2e5a

Browse files
committed
Use maxthreadid() in TSVI
1 parent 1f1bb01 commit f5e2e5a

File tree

2 files changed

+2
-7
lines changed

2 files changed

+2
-7
lines changed

src/threadsafe.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
1313
# fields. This is not good practice --- see
1414
# https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
1515
# explanation --- but it has worked okay so far.
16-
# The use of nthreads()*2 here ensures that threadid() doesn't exceed
17-
# the length of the logps array. Ideally, we would use maxthreadid(),
18-
# but Mooncake can't differentiate through that. Empirically, nthreads()*2
19-
# seems to provide an upper bound to maxthreadid(), so we use that here.
20-
# See https://github.com/TuringLang/DynamicPPL.jl/pull/936
21-
accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)]
16+
accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()]
2217
return ThreadSafeVarInfo(vi, accs_by_thread)
2318
end
2419
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi

test/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
@test threadsafe_vi.varinfo === vi
77
@test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple}
8-
@test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2
8+
@test length(threadsafe_vi.accs_by_thread) == Threads.maxthreadid()
99
expected_accs = DynamicPPL.AccumulatorTuple(
1010
(DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))...
1111
)

0 commit comments

Comments
 (0)