File tree Expand file tree Collapse file tree 2 files changed +2
-7
lines changed
Expand file tree Collapse file tree 2 files changed +2
-7
lines changed Original file line number Diff line number Diff line change @@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
1313 # fields. This is not good practice --- see
1414 # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
1515 # explanation --- but it has worked okay so far.
16- # The use of nthreads()*2 here ensures that threadid() doesn't exceed
17- # the length of the logps array. Ideally, we would use maxthreadid(),
18- # but Mooncake can't differentiate through that. Empirically, nthreads()*2
19- # seems to provide an upper bound to maxthreadid(), so we use that here.
20- # See https://github.com/TuringLang/DynamicPPL.jl/pull/936
21- accs_by_thread = [map (split, getaccs (vi)) for _ in 1 : (Threads. nthreads () * 2 )]
16+ accs_by_thread = [map (split, getaccs (vi)) for _ in 1 : Threads. maxthreadid ()]
2217 return ThreadSafeVarInfo (vi, accs_by_thread)
2318end
2419ThreadSafeVarInfo (vi:: ThreadSafeVarInfo ) = vi
Original file line number Diff line number Diff line change 55
66 @test threadsafe_vi. varinfo === vi
77 @test threadsafe_vi. accs_by_thread isa Vector{<: DynamicPPL.AccumulatorTuple }
8- @test length (threadsafe_vi. accs_by_thread) == Threads. nthreads () * 2
8+ @test length (threadsafe_vi. accs_by_thread) == Threads. maxthreadid ()
99 expected_accs = DynamicPPL. AccumulatorTuple (
1010 (DynamicPPL. split (acc) for acc in DynamicPPL. getaccs (vi)). ..
1111 )
You can’t perform that action at this time.
0 commit comments