Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,45 @@
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.

For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.

As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
If you were previously relying on this behaviour, you will need to store a VarInfo separately.

#### Threadsafe evaluation

DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel.
Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread.

In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as so, using:

```julia
@model f() = ...
model = f()
model = setthreadsafe(model, true)
```

The problem with the previous on-by-default is that it can sacrifice a huge amount of performance when thread safety is not needed.
This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.

**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:

- tilde-statements
- calls to `@addlogprob!`
- any direct manipulation of the special `__varinfo__` variable

If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
**Notably, the following do not require threadsafe evaluation:**

- Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
- Sampling with `AbstractMCMC.MCMCThreads()`.

For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).

#### Parent and leaf contexts

The `DynamicPPL.NodeTrait` function has been removed.
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref):
contextualize
```

Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
If this is the case, one must enable threadsafe evaluation for a model:

```@docs
setthreadsafe
```

## Evaluation

With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
setthreadsafe,
extract_priors,
values_as_in_model,
# evaluation
Expand Down
53 changes: 39 additions & 14 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
modeldef = build_model_definition(expr)

# Generate main body
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false)

return build_output(modeldef, linenumbernode)
end
Expand Down Expand Up @@ -346,36 +346,58 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
generate_mainbody(mod, expr, warn, warn_threads) =
generate_mainbody!(mod, Symbol[], expr, warn, warn_threads)

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found, x, warn, warn_threads) = x
function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads)
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Flag to determine whether we've issued a warning for threadsafe macros Note that this
# detection is not fully correct. We can only detect the presence of a macro that has
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
# Threads.@threads from Base.Threads.

# Do we don't want escaped expressions because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
Meta.isexpr(expr, :escape) &&
return generate_mainbody(mod, found, expr.args[1], warn, warn_threads)

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
if (
expr.args[1] == Symbol("@threads") ||
expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
!warn_threads
)
warn_threads = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the name, I would have guessed this would work the other way around with true/false (I read "warn" as imperative).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, let's flip it.

@warn (
"It looks like you are using `Threads.@threads` in your model definition." *
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
"\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required."
)
end
return generate_mainbody!(
mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads
)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_mainbody!(
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads
)
end

Expand All @@ -385,8 +407,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_tilde
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warn_threads),
generate_mainbody!(mod, found, R, warn, warn_threads),
),
)
end
Expand All @@ -397,13 +419,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_assign
return Base.remove_linenums!(
generate_assign(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warn_threads),
generate_mainbody!(mod, found, R, warn, warn_threads),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)...,
)
end

function generate_assign(left, right)
Expand Down Expand Up @@ -699,7 +724,7 @@ function build_output(modeldef, linenumbernode)
# to the call site
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...))
end

return MacroTools.@q begin
Expand Down
7 changes: 5 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,11 @@ function check_model_and_trace(
# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)

# Force single-threaded execution.
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
# check on the merged accumulator, rather than checking it in the accumulate_assume
# calls. That way we can also support multi-threaded evaluation and use `evaluate!!`
# here instead of `_evaluate!!`.
_, varinfo = DynamicPPL._evaluate!!(model, varinfo)

# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
Expand Down
Loading
Loading