From 33c3db118224326f99b147ec6c8aed3c22482af2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 20 Nov 2025 14:27:38 +0000 Subject: [PATCH 01/16] Make threadsafe evaluation opt-in --- HISTORY.md | 31 ++++++++++++- docs/src/api.md | 7 +++ src/DynamicPPL.jl | 1 + src/compiler.jl | 60 ++++++++++++++++++------ src/debug_utils.jl | 7 ++- src/model.jl | 111 +++++++++++++++++++++++---------------------- test/compiler.jl | 42 ++++++++++++++--- test/threadsafe.jl | 61 +++++-------------------- 8 files changed, 195 insertions(+), 125 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ff28349d8..5d6b086fb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -9,12 +9,41 @@ This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. -For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments. As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. If you were previously relying on this behaviour, you will need to store a VarInfo separately. +#### Threadsafe evaluation + +DynamicPPL models are by default no longer thread-safe. +If you have threading in a model, you **must** now manually mark it as so, using: + +```julia +@model f() = ... +model = f() +model = setthreadsafe(model, true) +``` + +It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`). + +The problem with this approach is that it sacrifices a huge amount of performance. +Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation. + +**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.** +This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros: + + - tilde-statements + - calls to `@addlogprob!` + - any direct manipulation of the special `__varinfo__` variable + +If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe. +**Notably, the following do not require threadsafe evaluation:** + + - Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation. + - Sampling with `AbstractMCMC.MCMCThreads()`. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. diff --git a/docs/src/api.md b/docs/src/api.md index adb476db5..7d7308f82 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref): contextualize ``` +Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary). +If this is the case, one must enable threadsafe evaluation for a model: + +```@docs +setthreadsafe +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a885f6a96..97664f952 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -90,6 +90,7 @@ export AbstractVarInfo, Model, getmissings, getargnames, + setthreadsafe, extract_priors, values_as_in_model, # evaluation diff --git a/src/compiler.jl b/src/compiler.jl index 3324780ca..87a3ad811 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn) modeldef = build_model_definition(expr) # Generate main body - modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn) + modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false) return build_output(modeldef, linenumbernode) end @@ -346,10 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +generate_mainbody(mod, expr, warn, warned_about_threads_threads) = + generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads) -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x +function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$sym`" push!(found, sym) @@ -357,17 +358,40 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Flag to determine whether we've issued a warning for threadsafe macros Note that this + # detection is not fully correct. We can only detect the presence of a macro that has + # the symbol `Threads.@threads`, however, we can't detect if that *is actually* + # Threads.@threads from Base.Threads. + # Do we don't want escaped expressions because we unfortunately # escape the entire body afterwards. - Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + Meta.isexpr(expr, :escape) && return generate_mainbody( + mod, found, expr.args[1], warn, warned_about_threads_threads + ) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && + !warned_about_threads_threads + warned_about_threads_threads = true + @warn ( + "It looks like you are using `Threads.@threads` in your model definition." * + "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." * + " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." * + "\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required." + ) + end + return generate_mainbody!( + mod, + found, + macroexpand(mod, expr; recursive=true), + warn, + warned_about_threads_threads, + ) end # Modify dotted tilde operators. @@ -375,7 +399,11 @@ function generate_mainbody!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde return generate_mainbody!( - mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn + mod, + found, + Base.remove_linenums!(generate_dot_tilde(L, R)), + warn, + warned_about_threads_threads, ) end @@ -385,8 +413,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warned_about_threads_threads), + generate_mainbody!(mod, found, R, warn, warned_about_threads_threads), ), ) end @@ -397,13 +425,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_assign return Base.remove_linenums!( generate_assign( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warned_about_threads_threads), + generate_mainbody!(mod, found, R, warn, warned_about_threads_threads), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map( + x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads), + expr.args, + )..., + ) end function generate_assign(left, right) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index e8b50a0b7..40a3cb3c1 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -424,8 +424,11 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a + # check on the merged accumulator, rather than checking it in the accumulate_assume + # calls. That way we can also support multi-threaded evaluation and use `evaluate!!` + # here instead of `_evaluate!!`. + _, varinfo = DynamicPPL._evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/model.jl b/src/model.jl index 7d5bbf2fb..8d9759f61 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However, non-traditional use-cases `missings` can be defined differently. All variables in `missings` are treated as random variables rather than observations. +The `Threaded` type parameter indicates whether the model requires threadsafe evaluation +(i.e., whether the model contains statements which modify the internal VarInfo that are +executed in parallel). By default, this is set to `false`. + The default arguments are used internally when constructing instances of the same model with different arguments. @@ -33,8 +37,9 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -46,13 +51,13 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ - function Model{missings}( + function Model{missings,Threaded}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}( f, args, defaults, context ) end @@ -71,6 +76,7 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), + threadsafe::Bool=false, ) where {F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing @@ -78,11 +84,19 @@ model with different arguments. missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...),threadsafe}( + f, args, defaults, context + )) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + threadsafe=false; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, threadsafe) end """ @@ -91,8 +105,10 @@ end Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ -function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) +function contextualize( + model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, context::AbstractContext +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return Model(model.f, model.args, model.defaults, context, Threaded) end """ @@ -105,6 +121,31 @@ function setleafcontext(model::Model, context::AbstractContext) return contextualize(model, setleafcontext(model.context, context)) end +""" + setthreadsafe(model::Model, threadsafe::Bool) + +Returns a new `Model` with its threadsafe flag set to `threadsafe`. + +Threadsafe evaluation allows for parallel execution of model statements that mutate the +internal `VarInfo` object. For example, this is needed if tilde-statements are nested inside +`Threads.@threads` or similar constructs. + +It is not needed for generic multithreaded operations that don't involve VarInfo. For +example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` +outside of the parallel region is safe without needing to set `threadsafe=true`. + +It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. +""" +function setthreadsafe( + model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, threadsafe::Bool +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return if Threaded == threadsafe + model + else + Model{M,threadsafe}(model.f, model.args, model.defaults, model.context) + end +end + """ model | (x = 1.0, ...) @@ -863,16 +904,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -944,40 +975,14 @@ If multiple threads are available, the varinfo provided will be wrapped in a Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ -function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) - else - evaluate_threadunsafe!!(model, varinfo) - end -end - -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) +function AbstractPPL.evaluate!!( + model::Model{F,A,D,M,Ta,Td,Ctx,false}, varinfo::AbstractVarInfo +) where {F,A,D,M,Ta,Td,Ctx} return _evaluate!!(model, resetaccs!!(varinfo)) end - -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) +function AbstractPPL.evaluate!!( + model::Model{F,A,D,M,Ta,Td,Ctx,true}, varinfo::AbstractVarInfo +) where {F,A,D,M,Ta,Td,Ctx} wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper) # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..9056f666a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. @model function demo() @@ -793,4 +788,39 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "Threads.@threads detection" begin + # Check that the compiler detects when `Threads.@threads` is used inside a model + + e1 = quote + @model function f1() + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e1) + + e2 = quote + @model function f2() + for j in 1:10 + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e2) + + e3 = quote + @model function f3() + begin + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e3) + end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..027e51422 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -52,63 +52,24 @@ x[i] ~ Normal(x[i - 1], 1) end end - model = wthreads(x) + model = setthreadsafe(wthreads(x), true) - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) + function correct_lp(x) + lp = logpdf(Normal(0, 1), x[1]) for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) + lp += logpdf(Normal(x[i - 1], 1), x[i]) end + return lp end - model = wothreads(x) vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time model(vi) + _, vi = DynamicPPL.evaluate!!(model, vi) - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo + # check that logp is correct + @test getlogjoint(vi) ≈ correct_lp(x) + # check that varinfo was wrapped during the model evaluation + @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From cba5d29a478d469a4a42ccf67c14aec64bb9a1ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Nov 2025 11:22:04 +0000 Subject: [PATCH 02/16] Reduce number of type parameters in methods --- src/model.jl | 49 ++++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8d9759f61..1d4128526 100644 --- a/src/model.jl +++ b/src/model.jl @@ -99,16 +99,20 @@ function Model( return Model(f, args, NamedTuple(kwargs), context, threadsafe) end +function _requires_threadsafe( + ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return Threaded +end + """ contextualize(model::Model, context::AbstractContext) Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ -function contextualize( - model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, context::AbstractContext -) where {F,A,D,M,Ta,Td,Ctx,Threaded} - return Model(model.f, model.args, model.defaults, context, Threaded) +function contextualize(model::Model, context::AbstractContext) + return Model(model.f, model.args, model.defaults, context, _requires_threadsafe(model)) end """ @@ -136,10 +140,8 @@ outside of the parallel region is safe without needing to set `threadsafe=true`. It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. """ -function setthreadsafe( - model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, threadsafe::Bool -) where {F,A,D,M,Ta,Td,Ctx,Threaded} - return if Threaded == threadsafe +function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} + return if _requires_threadsafe(model) == threadsafe model else Model{M,threadsafe}(model.f, model.args, model.defaults, model.context) @@ -969,27 +971,24 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. +If the model has been marked as requiring threadsafe evaluation, are available, the varinfo +provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ -function AbstractPPL.evaluate!!( - model::Model{F,A,D,M,Ta,Td,Ctx,false}, varinfo::AbstractVarInfo -) where {F,A,D,M,Ta,Td,Ctx} - return _evaluate!!(model, resetaccs!!(varinfo)) -end -function AbstractPPL.evaluate!!( - model::Model{F,A,D,M,Ta,Td,Ctx,true}, varinfo::AbstractVarInfo -) where {F,A,D,M,Ta,Td,Ctx} - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) +function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) + return if _requires_threadsafe(model) + wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) + else + _evaluate!!(model, resetaccs!!(varinfo)) + end end """ From 8d1fcaa9c3e21e3723d124d7462f0b2166b64ee8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Nov 2025 11:23:17 +0000 Subject: [PATCH 03/16] Make `warned_warn_about_threads_threads_threads_threads` shorter --- src/compiler.jl | 44 ++++++++++++++++---------------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 87a3ad811..3e635de7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -346,11 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn, warned_about_threads_threads) = - generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads) +generate_mainbody(mod, expr, warn, warn_threads) = + generate_mainbody!(mod, Symbol[], expr, warn, warn_threads) -generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x -function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads) +generate_mainbody!(mod, found, x, warn, warn_threads) = x +function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$sym`" push!(found, sym) @@ -358,7 +358,7 @@ function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_ return sym end -function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads) +function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] @@ -369,15 +369,14 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_t # Do we don't want escaped expressions because we unfortunately # escape the entire body afterwards. - Meta.isexpr(expr, :escape) && return generate_mainbody( - mod, found, expr.args[1], warn, warned_about_threads_threads - ) + Meta.isexpr(expr, :escape) && + return generate_mainbody(mod, found, expr.args[1], warn, warn_threads) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && - !warned_about_threads_threads - warned_about_threads_threads = true + !warn_threads + warn_threads = true @warn ( "It looks like you are using `Threads.@threads` in your model definition." * "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." * @@ -386,11 +385,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_t ) end return generate_mainbody!( - mod, - found, - macroexpand(mod, expr; recursive=true), - warn, - warned_about_threads_threads, + mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads ) end @@ -399,11 +394,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_t if args_dottilde !== nothing L, R = args_dottilde return generate_mainbody!( - mod, - found, - Base.remove_linenums!(generate_dot_tilde(L, R)), - warn, - warned_about_threads_threads, + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads ) end @@ -413,8 +404,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_t L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn, warned_about_threads_threads), - generate_mainbody!(mod, found, R, warn, warned_about_threads_threads), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end @@ -425,18 +416,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_t L, R = args_assign return Base.remove_linenums!( generate_assign( - generate_mainbody!(mod, found, L, warn, warned_about_threads_threads), - generate_mainbody!(mod, found, R, warn, warned_about_threads_threads), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end return Expr( expr.head, - map( - x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads), - expr.args, - )..., + map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)..., ) end From cb562809f7ed86b13bce89de61ccf0ada1a9ca5b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Nov 2025 11:25:26 +0000 Subject: [PATCH 04/16] Improve `setthreadsafe` docstring --- src/model.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 1d4128526..94ec9adcb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -130,15 +130,19 @@ end Returns a new `Model` with its threadsafe flag set to `threadsafe`. -Threadsafe evaluation allows for parallel execution of model statements that mutate the -internal `VarInfo` object. For example, this is needed if tilde-statements are nested inside -`Threads.@threads` or similar constructs. +Threadsafe evaluation ensures correctness when executing model statements that mutate the +internal `VarInfo` object in parallel. For example, this is needed if tilde-statements are +nested inside `Threads.@threads` or similar constructs. It is not needed for generic multithreaded operations that don't involve VarInfo. For example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` outside of the parallel region is safe without needing to set `threadsafe=true`. It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. + +Setting `threadsafe` to `true` increases the overhead in evaluating the model. See +(https://turinglang.org/docs/THIS_DOESNT_EXIST_YET)[https://turinglang.org/docs/THIS_DOESNT_EXIST_YET] +for more details. """ function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} return if _requires_threadsafe(model) == threadsafe From eeeb12f7c506dae316566c2dddf86e92d3ac214c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Nov 2025 19:03:08 +0000 Subject: [PATCH 05/16] warn on bare `@threads` as well --- src/compiler.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3e635de7a..67f6b7937 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -374,8 +374,11 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && + if ( + expr.args[1] == Symbol("@threads") || + expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && !warn_threads + ) warn_threads = true @warn ( "It looks like you are using `Threads.@threads` in your model definition." * From 0688f11c4800042b3ec94a51573c1bdca1cf79a4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Nov 2025 11:26:43 +0000 Subject: [PATCH 06/16] fix merge --- src/model.jl | 5 +-- test/logdensityfunction.jl | 78 ++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 45 deletions(-) diff --git a/src/model.jl b/src/model.jl index 94ec9adcb..364a156a2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -937,10 +937,7 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. ) ctx = InitContext(rng, strategy) model = DynamicPPL.setleafcontext(model, ctx) - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - return if Threads.nthreads() > 1 + return if _requires_threadsafe(model) # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` # won't have been filled with parameters prior to `init!!` being called. diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f43ed45a4..b75895015 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -51,21 +51,19 @@ using Mooncake: Mooncake end @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.LogDensityFunction(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end + N = 100 + model = setthreadsafe(threaded(zeros(N)), true) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end end @@ -125,34 +123,32 @@ end end @testset "LogDensityFunction: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). - @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(ldf, x)) - @test iszero(bench.allocs) - end + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(ldf, x)) + @test iszero(bench.allocs) end end From 1f1bb0151c993b8b9a8695f6a7486c94bc05627b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Nov 2025 11:45:29 +0000 Subject: [PATCH 07/16] Fix performance issues --- src/compiler.jl | 2 +- src/model.jl | 31 +++++++++++++------------------ test/logdensityfunction.jl | 2 +- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 67f6b7937..60f19d164 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -724,7 +724,7 @@ function build_output(modeldef, linenumbernode) # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) - return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...)) + return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...)) end return MacroTools.@q begin diff --git a/src/model.jl b/src/model.jl index 364a156a2..90c8749ca 100644 --- a/src/model.jl +++ b/src/model.jl @@ -46,12 +46,12 @@ struct Model{ context::Ctx @doc """ - Model{missings}(f, args::NamedTuple, defaults::NamedTuple) + Model{Threaded,missings}(f, args::NamedTuple, defaults::NamedTuple) Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ - function Model{missings,Threaded}( + function Model{Threaded,missings}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, @@ -71,32 +71,27 @@ Create a model with evaluation function `f` and missing arguments deduced from ` Default arguments `defaults` are used internally when constructing instances of the same model with different arguments. """ -@generated function Model( +@generated function Model{Threaded}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), - threadsafe::Bool=false, -) where {F,argnames,Targs,kwargnames,Tkwargs} +) where {Threaded,F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...),threadsafe}( + return :(Model{Threaded,$(missing_args..., missing_kwargs...)}( f, args, defaults, context )) end -function Model( - f, - args::NamedTuple, - context::AbstractContext=DefaultContext(), - threadsafe=false; - kwargs..., -) - return Model(f, args, NamedTuple(kwargs), context, threadsafe) +function Model{Threaded}( + f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs... +) where {Threaded} + return Model{Threaded}(f, args, NamedTuple(kwargs), context) end function _requires_threadsafe( @@ -112,7 +107,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context, _requires_threadsafe(model)) + return Model{_requires_threadsafe(model)}(model.f, model.args, model.defaults, context) end """ @@ -148,7 +143,7 @@ function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} return if _requires_threadsafe(model) == threadsafe model else - Model{M,threadsafe}(model.f, model.args, model.defaults, model.context) + Model{threadsafe,M}(model.f, model.args, model.defaults, model.context) end end @@ -955,9 +950,9 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. vi = DynamicPPL.setaccs!!(vi, accs) tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) - return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) else - return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + DynamicPPL._evaluate!!(model, resetaccs!!(vi)) end end @inline function init!!( diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index b75895015..1d609a013 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -147,7 +147,7 @@ end vi = VarInfo(model) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) x = vi[:] - bench = median(@be LogDensityProblems.logdensity(ldf, x)) + bench = median(@be LogDensityProblems.logdensity($ldf, $x)) @test iszero(bench.allocs) end end From f5e2e5a0edb437b7e0a609d4c7b5323bef599b6d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Nov 2025 12:19:48 +0000 Subject: [PATCH 08/16] Use maxthreadid() in TSVI --- src/threadsafe.jl | 7 +------ test/threadsafe.jl | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 89877f385..0e906b6ca 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo) # fields. This is not good practice --- see # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. - # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()] return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 027e51422..8bab2ba3e 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -5,7 +5,7 @@ @test threadsafe_vi.varinfo === vi @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 + @test length(threadsafe_vi.accs_by_thread) == Threads.maxthreadid() expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... ) From 284dc8885ed84ebb6585051615692be474a427b9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Nov 2025 12:19:54 +0000 Subject: [PATCH 09/16] Move convert_eltype code to threadsafe eval function --- src/model.jl | 43 +++++++++++++++++++------------------------ src/simple_varinfo.jl | 10 +--------- src/varinfo.jl | 16 +--------------- 3 files changed, 21 insertions(+), 48 deletions(-) diff --git a/src/model.jl b/src/model.jl index 90c8749ca..9709f6e74 100644 --- a/src/model.jl +++ b/src/model.jl @@ -921,7 +921,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -@inline function init!!( +function init!!( # Note that this `@inline` is mandatory for performance, especially for # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for # trivial models) and much slower runtime. @@ -932,30 +932,9 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. ) ctx = InitContext(rng, strategy) model = DynamicPPL.setleafcontext(model, ctx) - return if _requires_threadsafe(model) - # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that - # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` - # won't have been filled with parameters prior to `init!!` being called. - # - # Note that this eltype promotion is only needed for threadsafe evaluation. In an - # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a - # similar method. In other words, it should not be here, and it should not be inside - # `unflatten` either. The problem is performance. Shifting this code around can have - # massive, inexplicable, impacts on performance. This should be investigated - # properly. - param_eltype = DynamicPPL.get_param_eltype(strategy) - accs = map(vi.accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) - end - vi = DynamicPPL.setaccs!!(vi, accs) - tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) - retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) - retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) - else - DynamicPPL._evaluate!!(model, resetaccs!!(vi)) - end + return DynamicPPL.evaluate!!(model, vi) end -@inline function init!!( +function init!!( model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) # This `@inline` is also mandatory for performance @@ -975,6 +954,22 @@ Returns a tuple of the model's return value, plus the updated `varinfo` """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) return if _requires_threadsafe(model) + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. + # The below line is finicky for type stability. For instance, assigning the eltype to + # convert to into an intermediate variable makes this unstable (constant propagation) + # fails. Take care when editing. + param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context) + accs = map(DynamicPPL.getaccs(varinfo)) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + varinfo = DynamicPPL.setaccs!!(varinfo, accs) wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper) # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 434480be6..9d3fb1925 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -278,15 +278,7 @@ end function unflatten(svi::SimpleVarInfo, x::AbstractVector) vals = unflatten(svi.values, x) - # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is - # required but undesireable. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) - ) - return SimpleVarInfo(vals, accs, svi.transformation) + return SimpleVarInfo(vals, svi.accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..14e08515c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -367,21 +367,7 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is - # a gradient type of some AD backend. - # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! - # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but - # the accumulators in the VarInfo are plain floats, we error since we can't change the - # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here - # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just - # plain ugly and hacky. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) - ) - return VarInfo(md, accs) + return VarInfo(md, vi.accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in From 74ea4ac815a2df9d51019b63a0e6a69fd3ccb5e8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 27 Nov 2025 12:17:43 +0000 Subject: [PATCH 10/16] Point to new Turing docs page --- HISTORY.md | 16 ++++++++++------ docs/src/api.md | 2 +- src/compiler.jl | 2 +- src/model.jl | 6 +++--- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 5d6b086fb..f70aab244 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -17,8 +17,11 @@ If you were previously relying on this behaviour, you will need to store a VarIn #### Threadsafe evaluation -DynamicPPL models are by default no longer thread-safe. -If you have threading in a model, you **must** now manually mark it as so, using: +DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel. +Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread. + +In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**. +If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as so, using: ```julia @model f() = ... @@ -26,9 +29,8 @@ model = f() model = setthreadsafe(model, true) ``` -It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`). - -The problem with this approach is that it sacrifices a huge amount of performance. +The problem with the previous on-by-default is that it can sacrifice a huge amount of performance when thread safety is not needed. +This is especially true when running Julia in a notebook, where multiple threads are often enabled by default. Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation. **A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.** @@ -41,9 +43,11 @@ This can occur if any of the following are inside `Threads.@threads` or other co If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe. **Notably, the following do not require threadsafe evaluation:** - - Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation. + - Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation. - Sampling with `AbstractMCMC.MCMCThreads()`. +For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/). + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. diff --git a/docs/src/api.md b/docs/src/api.md index 7d7308f82..eaba4de82 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,7 +42,7 @@ The context of a model can be set using [`contextualize`](@ref): contextualize ``` -Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary). +Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary). If this is the case, one must enable threadsafe evaluation for a model: ```@docs diff --git a/src/compiler.jl b/src/compiler.jl index 60f19d164..d540cec12 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -384,7 +384,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) "It looks like you are using `Threads.@threads` in your model definition." * "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." * " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." * - "\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required." + "\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required." ) end return generate_mainbody!( diff --git a/src/model.jl b/src/model.jl index 9709f6e74..04c759652 100644 --- a/src/model.jl +++ b/src/model.jl @@ -135,9 +135,9 @@ outside of the parallel region is safe without needing to set `threadsafe=true`. It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. -Setting `threadsafe` to `true` increases the overhead in evaluating the model. See -(https://turinglang.org/docs/THIS_DOESNT_EXIST_YET)[https://turinglang.org/docs/THIS_DOESNT_EXIST_YET] -for more details. +Setting `threadsafe` to `true` increases the overhead in evaluating the model. Please see +[the Turing.jl docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more +details. """ function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} return if _requires_threadsafe(model) == threadsafe From f8c45c27e00b6db5f39d1faa062a138c4aaf847d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 27 Nov 2025 13:45:48 +0000 Subject: [PATCH 11/16] Add a test for setthreadsafe --- test/threadsafe.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 8bab2ba3e..eebaec7ac 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -12,6 +12,16 @@ @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end + @testset "setthreadsafe" begin + @model f() = x ~ Normal() + model = f() + @test !DynamicPPL._requires_threadsafe(model) + model = setthreadsafe(model, true) + @test DynamicPPL._requires_threadsafe(model) + model = setthreadsafe(model, false) + @test !DynamicPPL._requires_threadsafe(model) + end + # TODO: Add more tests of the public API @testset "API" begin vi = VarInfo(gdemo_default) @@ -41,8 +51,6 @@ end @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - x = rand(10_000) @model function wthreads(x) From 9d0edc433f5ae35e5df819b98a70afe629065d47 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 27 Nov 2025 14:03:31 +0000 Subject: [PATCH 12/16] Tidy up check_model --- src/debug_utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 40a3cb3c1..8810b9819 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -426,9 +426,8 @@ function check_model_and_trace( # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a # check on the merged accumulator, rather than checking it in the accumulate_assume - # calls. That way we can also support multi-threaded evaluation and use `evaluate!!` - # here instead of `_evaluate!!`. - _, varinfo = DynamicPPL._evaluate!!(model, varinfo) + # calls. That way we can also correctly support multi-threaded evaluation. + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) From 8ef33556d5a2c05096fd99912d70abb5cee12dea Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 10:59:16 +0000 Subject: [PATCH 13/16] Apply suggestions from code review Fix outdated docstrings Co-authored-by: Markus Hauru --- src/model.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/model.jl b/src/model.jl index 04c759652..7bf06d39d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -922,9 +922,6 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ function init!!( - # Note that this `@inline` is mandatory for performance, especially for - # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for - # trivial models) and much slower runtime. rng::Random.AbstractRNG, model::Model, vi::AbstractVarInfo, @@ -937,7 +934,6 @@ end function init!!( model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - # This `@inline` is also mandatory for performance return init!!(Random.default_rng(), model, vi, strategy) end @@ -963,8 +959,8 @@ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just # plain ugly and hacky. # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. + # convert to into an intermediate variable makes this unstable (constant propagation + # fails). Take care when editing. param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context) accs = map(DynamicPPL.getaccs(varinfo)) do acc DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) From 180ced8875441f65c6d1bdddb913751602969acf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 11:01:30 +0000 Subject: [PATCH 14/16] Improve warning message --- src/compiler.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d540cec12..1b4260121 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn) modeldef = build_model_definition(expr) # Generate main body - modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false) + modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true) return build_output(modeldef, linenumbernode) end @@ -377,14 +377,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) if ( expr.args[1] == Symbol("@threads") || expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && - !warn_threads + warn_threads ) - warn_threads = true + warn_threads = false @warn ( "It looks like you are using `Threads.@threads` in your model definition." * "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." * " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." * - "\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required." + "\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." * + "\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required." ) end return generate_mainbody!( From 79b5ca096255a388da8ca7aa24d86698b5aa1f05 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 11:03:50 +0000 Subject: [PATCH 15/16] Export `requires_threadsafe` --- HISTORY.md | 4 ++++ docs/src/api.md | 1 + src/DynamicPPL.jl | 1 + src/model.jl | 8 ++++---- test/threadsafe.jl | 6 +++--- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f70aab244..5dcb008d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -48,6 +48,10 @@ If you have none of these inside threaded blocks, then you do not need to mark y For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/). +When threadsafe evaluation is enabled for a model, an internal flag is set on the model. +The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean. +This function is newly exported in this version of DynamicPPL. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. diff --git a/docs/src/api.md b/docs/src/api.md index eaba4de82..193a6ce4c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -47,6 +47,7 @@ If this is the case, one must enable threadsafe evaluation for a model: ```@docs setthreadsafe +requires_threadsafe ``` ## Evaluation diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 97664f952..fda428eaa 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -91,6 +91,7 @@ export AbstractVarInfo, getmissings, getargnames, setthreadsafe, + requires_threadsafe, extract_priors, values_as_in_model, # evaluation diff --git a/src/model.jl b/src/model.jl index 7bf06d39d..860274b04 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,7 +94,7 @@ function Model{Threaded}( return Model{Threaded}(f, args, NamedTuple(kwargs), context) end -function _requires_threadsafe( +function requires_threadsafe( ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} ) where {F,A,D,M,Ta,Td,Ctx,Threaded} return Threaded @@ -107,7 +107,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model{_requires_threadsafe(model)}(model.f, model.args, model.defaults, context) + return Model{requires_threadsafe(model)}(model.f, model.args, model.defaults, context) end """ @@ -140,7 +140,7 @@ Setting `threadsafe` to `true` increases the overhead in evaluating the model. P details. """ function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} - return if _requires_threadsafe(model) == threadsafe + return if requires_threadsafe(model) == threadsafe model else Model{threadsafe,M}(model.f, model.args, model.defaults, model.context) @@ -949,7 +949,7 @@ Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if _requires_threadsafe(model) + return if requires_threadsafe(model) # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is # a gradient type of some AD backend. # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! diff --git a/test/threadsafe.jl b/test/threadsafe.jl index eebaec7ac..879e936d6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -15,11 +15,11 @@ @testset "setthreadsafe" begin @model f() = x ~ Normal() model = f() - @test !DynamicPPL._requires_threadsafe(model) + @test !DynamicPPL.requires_threadsafe(model) model = setthreadsafe(model, true) - @test DynamicPPL._requires_threadsafe(model) + @test DynamicPPL.requires_threadsafe(model) model = setthreadsafe(model, false) - @test !DynamicPPL._requires_threadsafe(model) + @test !DynamicPPL.requires_threadsafe(model) end # TODO: Add more tests of the public API From e839de88ab838d3db0e595793729a4abb1f97bde Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 11:32:38 +0000 Subject: [PATCH 16/16] Add an actual docstring for `requires_threadsafe` --- src/model.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/model.jl b/src/model.jl index 860274b04..e82fdc60c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,12 @@ function Model{Threaded}( return Model{Threaded}(f, args, NamedTuple(kwargs), context) end +""" + requires_threadsafe(model::Model) + +Return whether `model` has been marked as needing threadsafe evaluation (using +`setthreadsafe`). +""" function requires_threadsafe( ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} ) where {F,A,D,M,Ta,Td,Ctx,Threaded}