Commit 0848518

Update threadsafe docs for v0.42

1 parent 9ecf0c3 commit 0848518

1 file changed: usage/threadsafe-evaluation/index.qmd (+110, -47 lines)
@@ -6,6 +6,13 @@ julia:
     - "--threads=4"
 ---
 
+```{julia}
+#| echo: false
+#| output: false
+using Pkg;
+Pkg.instantiate();
+```
+
 A common technique to speed up Julia code is to use multiple threads to run computations in parallel.
 The Julia manual [has a section on multithreading](https://docs.julialang.org/en/v1/manual/multi-threading), which is a good introduction to the topic.
 
@@ -17,26 +24,21 @@ Please note that this is a rapidly-moving topic, and things may change in future
 If you are ever unsure about what works and doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48)
 :::
 
-## MCMC sampling
-
-For complete clarity, this page has nothing to do with parallel sampling of MCMC chains using
-
-```julia
-sample(model, sampler, MCMCThreads(), N, nchains)
+```{julia}
+println("This notebook is being run with $(Threads.nthreads()) threads.")
 ```
 
-That parallelisation exists outside of the model evaluation, and thus is independent of the model contents.
-This page only discusses threading _inside_ Turing models.
-
 ## Threading in Turing models
 
 Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs such as `Threads.@threads` or `Threads.@spawn` can be used inside Turing models.
 
 This is, to some extent, true: for example, you can use threading constructs to speed up deterministic computations.
 For example, here we use parallelism to speed up a transformation of `x`:
 
-```julia
-@model function f(y)
+```{julia}
+using Turing
+
+@model function parallel(y)
     x ~ dist
     x_transformed = similar(x)
     Threads.@threads for i in eachindex(x)
@@ -48,8 +50,11 @@ end
 
 In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code.
 
-**However, extra care must be taken when using tilde-statements (`x ~ dist`) inside threaded blocks.**
-The reason for this is because tilde-statements modify the internal VarInfo object used for model evaluation.
+**However, extra care must be taken when using tilde-statements (`x ~ dist`) or `@addlogprob!` inside threaded blocks.**
+
+::: {.callout-note}
+## Why are tilde-statements special?
+Tilde-statements are expanded by the `@model` macro into code that modifies the internal VarInfo object used for model evaluation.
 Essentially, `x ~ dist` expands to something like
 
 ```julia
@@ -58,16 +63,17 @@ x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__)
 
 and writing into `__varinfo__` is, _in general_, not threadsafe.
 Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races).
+:::
+
+## Threaded observations
 
-## Threaded tilde-statements
+**As of version 0.42, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**
 
-**As of version 0.41, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**
+However, such models **must** be marked by the user as requiring threadsafe evaluation, using `setthreadsafe`.
 
 This means that the following code is safe to use:
 
 ```{julia}
-using Turing
-
 @model function threaded_obs(N)
     x ~ Normal()
     y = Vector{Float64}(undef, N)
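For readers of this diff, the hunk above truncates the body of `threaded_obs`. Based on the surrounding prose, the definition presumably continues with an observation loop along these lines (a hypothetical sketch, not the file's actual code; the `Normal(x)` likelihood is an assumption):

```julia
# Hypothetical completion of `threaded_obs`: each observation (likelihood
# term) is scored inside the threaded block, which v0.42 supports as long
# as the model is marked with `setthreadsafe`.
using Turing

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end
```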
@@ -78,13 +84,14 @@ end
 
 N = 100
 y = randn(N)
-model = threaded_obs(N) | (; y = y)
+threadunsafe_model = threaded_obs(N) | (; y = y)
+threadsafe_model = setthreadsafe(threadunsafe_model, true)
 ```
 
 Evaluating this model is threadsafe, in that Turing guarantees to provide the correct result in functions such as:
 
 ```{julia}
-logjoint(model, (; x = 0.0))
+logjoint(threadsafe_model, (; x = 0.0))
 ```
 
 (we can compare with the true value)
@@ -93,29 +100,36 @@ logjoint(model, (; x = 0.0))
 logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y))
 ```
 
-When sampling, you must disable model checking, but otherwise results will be correct:
+Note that if you do not use `setthreadsafe`, the above code may give wrong results, or even error:
 
 ```{julia}
-sample(model, NUTS(), 100; check_model=false, progress=false)
+logjoint(threadunsafe_model, (; x = 0.0))
 ```
 
-::: {.callout-warning}
-## Upcoming changes
+You can sample from this model and safely use functions such as `predict` or `returned`, as long as the model is always marked as threadsafe:
 
-Starting from DynamicPPL 0.39, if you use tilde-statements or `@addlogprob!` inside threaded blocks, you will have to declare this upfront using:
+```{julia}
+model = setthreadsafe(threaded_obs(N) | (; y = y), true)
+chn = sample(model, NUTS(), 100; check_model=false, progress=false)
+```
 
-```julia
-model = threaded_obs() | (; y = randn(N))
-threadsafe_model = setthreadsafe(model, true)
+```{julia}
+pmodel = setthreadsafe(threaded_obs(N), true) # don't condition on data
+predict(pmodel, chn)
 ```
 
-Then you can sample from `threadsafe_model` as before.
+::: {.callout-warning}
+## Previous versions
+
+Up until Turing v0.41, you did not need to use `setthreadsafe` to enable threadsafe evaluation; it was automatically enabled whenever Julia was launched with more than one thread.
+
+There were several reasons for changing this: a major one is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial (see below).
 
-The reason for this change is because threadsafe evaluation comes with a performance cost, which can sometimes be substantial.
-In the past, threadsafe evaluation was always enabled, i.e., this cost was *always* incurred whenever Julia was launched with more than one thread.
-However, this is not an appropriate way to determine whether threadsafe evaluation is needed!
+Furthermore, the number of threads is not an appropriate way to determine whether threadsafe evaluation is needed!
 :::
 
+## Threaded assumptions / sampling latent values
+
 **On the other hand, parallelising the sampling of latent values is not supported.**
 Attempting to do this will either error or give wrong results.
 
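The definition of `threaded_assume_bad`, referenced in the next hunk, falls in lines elided between these hunks. Its shape is presumably something like the following sketch (hypothetical; the file's actual definition may differ): latent values drawn inside a threaded block, which is exactly the unsupported pattern.

```julia
# Hypothetical sketch: tilde-statements that *sample* latent values inside
# a threaded block are not supported and may error or give wrong results.
using Turing

@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        x[i] ~ Normal()  # assumption (latent draw) inside threads: unsupported
    end
end
```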
@@ -136,25 +150,72 @@ model = threaded_assume_bad(100)
 model()
 ```
 
-**Note, in particular, that this means that you cannot currently use `predict` to sample new data in parallel.**
+## When is threadsafe evaluation really needed?
 
-:::{.callout-note}
-## Threaded `predict`
+You only need to enable threadsafe evaluation if you are using tilde-statements or `@addlogprob!` inside threaded blocks.
 
-Support for threaded `predict` will be added in DynamicPPL 0.39 (see [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130)).
-:::
+Specifically, you do *not* need to enable threadsafe evaluation if:
+
+- You have parallelism inside the model, but it does not involve tilde-statements or `@addlogprob!`.
+
+  ```julia
+  @model function parallel_no_tilde(y)
+      x ~ Normal()
+      fy = similar(y)
+      Threads.@threads for i in eachindex(y)
+          fy[i] = some_expensive_function(x, y[i])
+      end
+  end
+  # This does not need setthreadsafe
+  model = parallel_no_tilde(y)
+  ```
+
+- You are sampling from a model using `MCMCThreads()`, but the model itself does not contain any parallel tilde-statements or `@addlogprob!`.
+
+  ```julia
+  @model function no_parallel(y)
+      x ~ Normal()
+      y ~ Normal(x)
+  end
+
+  # This does not need setthreadsafe
+  model = no_parallel(1.0)
+  chn = sample(model, NUTS(), MCMCThreads(), 100)
+  ```
+
+## Performance considerations
+
+As described above, one of the major considerations behind the introduction of `setthreadsafe` is that threadsafe evaluation comes with a performance cost.
 
-That is, even for `threaded_obs` where `y` was originally an observed term, you _cannot_ do:
+Consider a simple model that does not use threading:
 
 ```{julia}
-#| error: true
-model = threaded_obs(N) | (; y = y)
-chn = sample(model, NUTS(), 100; check_model=false, progress=false)
+@model function gdemo()
+    s ~ InverseGamma(2, 3)
+    m ~ Normal(0, sqrt(s))
+    1.5 ~ Normal(m, sqrt(s))
+    2.0 ~ Normal(m, sqrt(s))
+end
+model_no_threadsafe = gdemo()
+model_threadsafe = setthreadsafe(gdemo(), true)
+```
 
-pmodel = threaded_obs(N) # don't condition on data
-predict(pmodel, chn)
+One can see that evaluation of the threadsafe model is substantially slower:
+
+```{julia}
+using Chairmarks, DynamicPPL
+
+function benchmark_eval(m)
+    vi = VarInfo(m)
+    display(median(@be DynamicPPL.evaluate!!($m, $vi)))
+end
+
+benchmark_eval(model_no_threadsafe)
+benchmark_eval(model_threadsafe)
 ```
 
+In previous versions of Turing, this cost would **always** be incurred whenever Julia was launched with multiple threads, even if the model did not use any threading at all!
+
 ## Alternatives to threaded observation
 
 An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).
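The `threaded_obs_addlogprob` model used in the next hunk is not shown in this diff. The pattern just described (parallel partial log-likelihoods, combined and added outside the threaded block) might be sketched as follows; the chunking scheme here is illustrative, not necessarily what the file uses:

```julia
using Turing

@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()
    # Compute per-chunk log-likelihoods in parallel. No tilde-statements or
    # `@addlogprob!` appear inside the parallel region.
    chunks = Iterators.partition(1:N, cld(N, Threads.nthreads()))
    tasks = map(chunks) do idxs
        Threads.@spawn sum(logpdf(Normal(x), y[i]) for i in idxs)
    end
    # The combined term is added outside the threaded block, so the model
    # does not need `setthreadsafe`.
    @addlogprob! sum(fetch.(tasks))
end
```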
@@ -198,8 +259,10 @@ On the other hand, one benefit of rewriting the model this way is that sampling
 using Random
 N = 100
 y = randn(N)
+# Note that since `@addlogprob!` is outside of the threaded block, we don't
+# need to use `setthreadsafe`.
 model = threaded_obs_addlogprob(N, y)
-nuts_kwargs = (check_model=false, progress=false, verbose=false)
+nuts_kwargs = (progress=false, verbose=false)
 
 chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
 chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
@@ -210,8 +273,8 @@ In contrast, the original `threaded_obs` (which used tilde inside `Threads.@thre
 (In principle, we would like to fix this bug, but we haven't yet investigated where it stems from.)
 
 ```{julia}
-model = threaded_obs(N) | (; y = y)
-nuts_kwargs = (check_model=false, progress=false, verbose=false)
+model = setthreadsafe(threaded_obs(N) | (; y = y), true)
+nuts_kwargs = (progress=false, verbose=false)
 chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
 chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
 mean(chain1[:x]), mean(chain2[:x]) # oops!
@@ -258,13 +321,13 @@ As it happens, much of what is needed in DynamicPPL can be constructed such that
 For example, as long as there is no need to *sample* new values of random variables, it is actually fine to completely omit the metadata object.
 This is the case for `LogDensityFunction`: since values are provided as the input vector, there is no need to store it in metadata.
 We need only calculate the associated log-prior probability, which is stored in an accumulator.
-Thus, starting from DynamicPPL v0.39, `LogDensityFunction` itself will in fact be completely threadsafe.
+Thus, since DynamicPPL v0.39, `LogDensityFunction` itself is completely threadsafe.
 
 Technically speaking, this is achieved using `OnlyAccsVarInfo`, which is a subtype of `VarInfo` that only contains accumulators, and no metadata at all.
 It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.
 
 There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can.
-For example, this is why `predict` will be threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator` instead, and combine them at the end of evaluation.
+For example, this is why `predict` is threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator`, and combine them at the end of evaluation.
 
 However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata).
 See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.
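The accumulator design described above relies on a standard race-free pattern: each parallel task reduces into its own local state, and the per-task results are combined once at the end. A generic illustration in plain Julia (a sketch of the general pattern, not DynamicPPL's actual code):

```julia
# Race-free parallel reduction: no shared mutable state inside the loop.
# Each spawned task owns its partial sum; partials are combined at the end.
function parallel_sum(xs::Vector{Float64})
    chunks = Iterators.partition(eachindex(xs), cld(length(xs), Threads.nthreads()))
    tasks = map(chunks) do idxs
        Threads.@spawn sum(@view xs[idxs])
    end
    return sum(fetch.(tasks))
end
```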
