usage/threadsafe-evaluation/index.qmd
This page specifically discusses Turing's support for threadsafe model evaluation.

:::{.callout-note}
Please note that this is a rapidly-moving topic, and things may change in future releases of Turing.
If you are ever unsure about what works and doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48)!
:::

Starting from DynamicPPL 0.39, if you use tilde-statements or `@addlogprob!` inside threaded blocks, you will have to declare this upfront using:
```julia
model = threaded_obs() | (; y = randn(N))
```
```julia
model = threaded_assume_bad(100)
model()
```
**Note, in particular, that this means that you cannot currently use `predict` to sample new data in parallel.**

:::{.callout-note}
## Threaded `predict`

Support for threaded `predict` will be added in DynamicPPL 0.39 (see [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130)).
:::

That is, even for `threaded_obs` where `y` was originally an observed term, you _cannot_ do:
```{julia}
pmodel = threaded_obs(N) # don't condition on data
predict(pmodel, chn)
```
## Alternatives to threaded observation
An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).
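
As a sketch of this pattern, here is a hypothetical, package-free example (the names `normal_logpdf` and `loglikelihood_threaded` are illustrative, not Turing API; a hand-written Normal log-density stands in for `Distributions.logpdf`):

```julia
# Hypothetical sketch: compute per-observation log-likelihood terms in a
# threaded loop, then sum them once, outside the threaded block. The total
# is what you would pass to `@addlogprob!` inside a model.
normal_logpdf(x, mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2pi)

function loglikelihood_threaded(y, mu, sigma)
    ll = zeros(length(y))
    Threads.@threads for i in eachindex(y)
        ll[i] = normal_logpdf(y[i], mu, sigma)  # each iteration writes only its own slot
    end
    return sum(ll)  # summed outside the threaded region
end
```

Because each iteration writes to a distinct slot of `ll`, there is no shared mutable state inside the threaded loop.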
See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-t
We make no promises about the use of tilde-statements _with_ these packages (indeed, this will most likely error), but as long as you only use them to parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.
The main downsides of this approach are:
1. You can't use conditioning syntax to provide data; it has to be passed as an argument or otherwise included inside the model.
2. You can't use `predict` to sample new data.
On the other hand, one benefit of rewriting the model this way is that sampling from this model with `MCMCThreads()` will always be reproducible.
This part will likely only be of interest to DynamicPPL developers and the very curious user.
:::
### Why is VarInfo not threadsafe?
As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.

Traditionally, VarInfo objects contain both *metadata* and *accumulators*.
Metadata is where information about the random variables' values is stored.
It is a Dict-like structure, and pushing to it from multiple threads is therefore not threadsafe (Julia's `Dict` has similar limitations).
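
To illustrate the general hazard (a toy example, not DynamicPPL code): inserting into a plain `Dict` from several threads is a data race, so concurrent writes must be serialised, for example with a lock:

```julia
# Toy illustration: concurrent inserts into a Dict must be serialised.
# Without the lock, this loop could corrupt the Dict's internal table
# or silently drop entries.
function fill_dict_safely(n)
    d = Dict{Int,Int}()
    lk = ReentrantLock()
    Threads.@threads for i in 1:n
        lock(lk) do
            d[i] = i * i  # guarded: only one thread mutates d at a time
        end
    end
    return d
end
```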

On the other hand, accumulators are used to store outputs of the model, such as log-probabilities.
The way DynamicPPL's threadsafe evaluation works is to create one set of accumulators per thread, and then combine the results at the end of model evaluation.
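
The strategy can be sketched in plain Julia (an illustration of the idea, not DynamicPPL's actual implementation):

```julia
using Base.Threads: @threads, nthreads

# One accumulator per chunk: each task touches only its own slot, and the
# partial results are combined only after the threaded region has finished.
function threaded_sum(xs)
    n = nthreads()
    partials = zeros(n)
    chunk = cld(length(xs), n)
    @threads for c in 1:n
        lo = (c - 1) * chunk + 1
        hi = min(c * chunk, length(xs))
        acc = 0.0
        for j in lo:hi
            acc += xs[j]
        end
        partials[c] = acc  # no two tasks share a slot
    end
    return sum(partials)  # combine at the end of evaluation
end
```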

In this way, any function call that _solely_ involves accumulators can be made threadsafe.
For example, this is why observations are supported: there is no need to modify metadata, and only the log-likelihood accumulator needs to be updated.

However, `assume` tilde-statements always modify the metadata, and thus cannot currently be made threadsafe.
### OnlyAccsVarInfo

As it happens, many of the objects needed in DynamicPPL can be constructed such that they *only* rely on accumulators.

For example, as long as there is no need to *sample* new values of random variables, it is actually fine to completely omit the metadata object.
This is the case for `LogDensityFunction`: since values are provided as the input vector, there is no need to store them in metadata.
We need only calculate the associated log-prior probability, which is stored in an accumulator.
Thus, starting from DynamicPPL v0.39, `LogDensityFunction` itself is in fact completely threadsafe.

Technically speaking, this is achieved using `OnlyAccsVarInfo`, which is a subtype of `VarInfo` that only contains accumulators, and no metadata at all.
It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.
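
As a toy illustration of this design (hypothetical types, not the real DynamicPPL interface): a container holding only a log-probability accumulator can support accumulation, but errors on any attempt to access metadata:

```julia
# Toy stand-in for the accumulator-only idea: log-probability can be
# accumulated, but there is no metadata, so reading a variable's value errors.
mutable struct ToyOnlyAccs
    logp::Float64
end

acclogp!(vi::ToyOnlyAccs, x) = (vi.logp += x; vi)
getvalue(vi::ToyOnlyAccs, vn) = error("ToyOnlyAccs stores no metadata for $vn")
```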

There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can.
For example, this is why `predict` will be threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator` instead, and combine them at the end of evaluation.

However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata).
See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.