A common technique to speed up Julia code is to use multiple threads to run computations in parallel.
The Julia manual [has a section on multithreading](https://docs.julialang.org/en/v1/manual/multi-threading), which is a good introduction to the topic.
::: {.callout-note}
Please note that this is a rapidly-moving topic, and things may change in future.
If you are ever unsure about what works and doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48).
:::
```{julia}
println("This notebook is being run with $(Threads.nthreads()) threads.")
```

## MCMC sampling

For complete clarity, this page has nothing to do with parallel sampling of MCMC chains using

```julia
sample(model, sampler, MCMCThreads(), N, nchains)
```

That parallelisation exists outside of the model evaluation, and thus is independent of the model contents.
This page only discusses threading _inside_ Turing models.
## Threading in Turing models
Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs such as `Threads.@threads` or `Threads.@spawn` can be used inside Turing models.
This is, to some extent, true: you can certainly use threading constructs to speed up deterministic computations.
For example, here we use parallelism to speed up a transformation of `x`:
```{julia}
using Turing

@model function parallel(y)
    x ~ dist  # `dist` is assumed to be defined earlier (not shown in this excerpt)
    x_transformed = similar(x)
    Threads.@threads for i in eachindex(x)
        # ... (loop body and remainder of the model elided)
    end
end
```
In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code.
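As a standalone illustration (plain Julia only; the array, function, and values here are made up for this example), each iteration below writes to its own index, so no synchronisation is needed:

```julia
# Each iteration writes to a distinct index of `ys`, so the loop is race-free.
xs = collect(1.0:100.0)
ys = similar(xs)
Threads.@threads for i in eachindex(xs)
    ys[i] = sqrt(xs[i])  # any deterministic, independent computation
end
```

This is the same pattern as the `x_transformed` example above, just without any surrounding model.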
**However, extra care must be taken when using tilde-statements (`x ~ dist`), or `@addlogprob!`, inside threaded blocks.**
::: {.callout-note}
## Why are tilde-statements special?
Tilde-statements are expanded by the `@model` macro into code that modifies the internal VarInfo object used for model evaluation (available inside the model body as `__varinfo__`), and writing into `__varinfo__` is, _in general_, not threadsafe.
Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races).
:::
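To see why unsynchronised writes are a problem, here is a plain-Julia sketch (no Turing involved; the numbers are arbitrary) of the standard safe alternative: give each task its own accumulator and combine the results afterwards.

```julia
xs = ones(1_000_000)

# UNSAFE pattern (do not use): `total += x` is a read-modify-write on shared
# state, so concurrent iterations can interleave and silently drop updates.
#
#   total = 0.0
#   Threads.@threads for i in eachindex(xs)
#       total += xs[i]
#   end

# SAFE pattern: one partial sum per chunk, combined after all tasks finish.
chunks = Iterators.partition(xs, 100_000)
tasks = [Threads.@spawn sum(c) for c in chunks]
total = sum(fetch.(tasks))
```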
## Threaded observations
**As of version 0.42, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**
However, such models **must** be marked by the user as requiring threadsafe evaluation, using `setthreadsafe`.
This means that the following code is safe to use:
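The original definition of `threaded_obs` falls outside this excerpt; the following is an illustrative reconstruction (the model body and the use of `condition` are assumptions, not the original code). The key point it demonstrates is that the only tilde-statements inside the threaded block are observations:

```julia
using Turing

@model function threaded_obs(N)
    mu ~ Normal()
    y = Vector{Float64}(undef, N)
    # With `y` conditioned on data (below), each `y[i] ~ ...` is an
    # observation (likelihood term), which is the one kind of
    # tilde-statement allowed inside threaded blocks.
    Threads.@threads for i in 1:N
        y[i] ~ Normal(mu)
    end
end

# Condition on observed data, then mark the model as requiring
# threadsafe evaluation before sampling.
N = 100
y = randn(N)
model = setthreadsafe(condition(threaded_obs(N); y = y), true)
chn = sample(model, NUTS(), 100)
```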
```{julia}
pmodel = setthreadsafe(threaded_obs(N), true) # don't condition on data
predict(pmodel, chn)
```
::: {.callout-warning}
## Previous versions
Up until Turing v0.41, you did not need to use `setthreadsafe` to enable threadsafe evaluation, and it was automatically enabled whenever Julia was launched with more than one thread.
There were several reasons for changing this: one major one is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial (see below).
Furthermore, the number of threads is not an appropriate way to determine whether threadsafe evaluation is needed!
:::
## Threaded assumptions / sampling latent values
**On the other hand, parallelising the sampling of latent values is not supported.**
Attempting to do this will either error or give wrong results.
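The definition of `threaded_assume_bad`, referenced just below, is also outside this excerpt; a sketch consistent with the text (a hypothetical body, not the original) would parallelise assume statements, i.e. the sampling of latent values:

```julia
using Turing

@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    # NOT SUPPORTED: `x[i] ~ Normal()` is an assume statement (a latent
    # variable, since `x` is not conditioned on data), and parallelising
    # it will either error or give wrong results.
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
end
```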
```{julia}
model = threaded_assume_bad(100)
model()
```
## When is threadsafe evaluation really needed?
You only need to enable threadsafe evaluation if you are using tilde-statements or `@addlogprob!` inside threaded blocks.
Specifically, you do *not* need to enable threadsafe evaluation if:
158
+
159
+
- You have parallelism inside the model, but it does not involve tilde-statements or `@addlogprob!`.
160
+
161
+
  ```julia
  @model function parallel_no_tilde(y)
      x ~ Normal()
      fy = similar(y)
      Threads.@threads for i in eachindex(y)
          fy[i] = some_expensive_function(x, y[i])
      end
  end

  # This does not need setthreadsafe
  model = parallel_no_tilde(y)
  ```
- You are sampling from a model using `MCMCThreads()`, but the model itself does not contain any parallel tilde-statements or `@addlogprob!`.
  ```julia
  @model function no_parallel(y)
      x ~ Normal()
      y ~ Normal(x)
  end

  # This does not need setthreadsafe
  model = no_parallel(1.0)
  chn = sample(model, NUTS(), MCMCThreads(), 100)
  ```
## Performance considerations
As described above, one of the major considerations behind the introduction of `setthreadsafe` is that threadsafe evaluation comes with a performance cost.
Consider a simple model that does not use threading: even then, threadsafe evaluation adds overhead to every model evaluation.
In previous versions of Turing, this cost would **always** be incurred whenever Julia was launched with multiple threads, even if the model did not use any threading at all!
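A rough way to see this cost for yourself (a sketch, not from the original page; the model here is a stand-in, and exact timings will vary by machine) is to time repeated evaluations of the same model with and without threadsafe evaluation enabled:

```julia
using Turing

@model function tiny(y)
    x ~ Normal()
    y ~ Normal(x)
end

m = tiny(1.0)
tm = setthreadsafe(m, true)  # same model, threadsafe evaluation enabled

# Calling a model object runs a single evaluation; the threadsafe version
# carries extra per-tilde bookkeeping, which these timings expose.
@time for _ in 1:10_000; m(); end
@time for _ in 1:10_000; tm(); end
```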
## Alternatives to threaded observation
159
220
160
221
An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).
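As a plain-Julia sketch of this pattern (hand-rolling the Normal log-density so the snippet needs nothing beyond Base; in a real model you would compute this inside `@model` and pass the total to `@addlogprob!`):

```julia
# Log-density of Normal(mu, 1) at y, written out by hand.
normal_logpdf(y, mu) = -0.5 * (y - mu)^2 - 0.5 * log(2pi)

mu = 0.3
ys = [0.1, -0.2, 0.5, 1.0]

# Parallel: one partial log-likelihood per task, combined afterwards.
tasks = [Threads.@spawn normal_logpdf(y, mu) for y in ys]
loglik_parallel = sum(fetch.(tasks))

# Serial reference computation for comparison.
loglik_serial = sum(normal_logpdf(y, mu) for y in ys)
```

Because the combination step happens after all tasks have finished, no tilde-statements run inside the threaded code, and threadsafe evaluation is not required.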
On the other hand, one benefit of rewriting the model this way is that sampling no longer requires threadsafe evaluation:

```{julia}
using Random
N = 100
y = randn(N)
# Note that since `@addlogprob!` is outside of the threaded block, we don't
# need to use `setthreadsafe` here.
```

As it happens, much of what is needed in DynamicPPL can be constructed such that the metadata object is not needed at all.
For example, as long as there is no need to *sample* new values of random variables, it is actually fine to completely omit the metadata object.
This is the case for `LogDensityFunction`: since values are provided as the input vector, there is no need to store it in metadata.
We need only calculate the associated log-prior probability, which is stored in an accumulator.
Thus, since DynamicPPL v0.39, `LogDensityFunction` itself is completely threadsafe.
Technically speaking, this is achieved using `OnlyAccsVarInfo`, which is a subtype of `VarInfo` that only contains accumulators, and no metadata at all.
It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.
There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can.
For example, this is why `predict` is threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator`, and combine them at the end of evaluation.
However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata).
See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.