Commit 91c5c18 ("Finish the last section")
1 parent 0b80363 commit 91c5c18

1 file changed: usage/threadsafe-evaluation/index.qmd (52 additions, 19 deletions)
@@ -14,7 +14,7 @@ This page specificaly discusses Turing's support for threadsafe model evaluation
 
 :::{.callout-note}
 Please note that this is a rapidly-moving topic, and things may change in future releases of Turing.
-If you are ever unsure about what works and doesn't, please don't hesitate to ask on Slack or Discourse (links can be found at the footer of this site)!
+If you are ever unsure about what works and doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48)!
 :::
 
 ## MCMC sampling
@@ -102,7 +102,7 @@ sample(model, NUTS(), 100; check_model=false, progress=false)
 ::: {.callout-warning}
 ## Upcoming changes
 
-In the next release of Turing, if you use tilde-observations inside threaded blocks, you will have to declare this upfront using:
+Starting from DynamicPPL 0.39, if you use tilde-statements or `@addlogprob!` inside threaded blocks, you will have to declare this upfront using:
 
 ```julia
 model = threaded_obs() | (; y = randn(N))
@@ -136,7 +136,14 @@ model = threaded_assume_bad(100)
 model()
 ```
 
-**Note, in particular, that this means that you cannot use `predict` to sample new data in parallel.**
+**Note, in particular, that this means that you cannot currently use `predict` to sample new data in parallel.**
+
+:::{.callout-note}
+## Threaded `predict`
+
+Support for threaded `predict` will be added in DynamicPPL 0.39 (see [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130)).
+:::
+
 That is, even for `threaded_obs` where `y` was originally an observed term, you _cannot_ do:
 
 ```{julia}
@@ -148,13 +155,6 @@ pmodel = threaded_obs(N) # don't condition on data
 predict(pmodel, chn)
 ```
 
-
-:::{.callout-note}
-## Threaded `predict`
-
-Support for the above call to `predict` may land in the near future, with [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130).
-:::
-
 ## Alternatives to threaded observation
 
 An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).
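This alternative can be sketched as follows. Note that the model name `threaded_obs_addlogprob` is taken from its usage later on this page, but the body below is an illustrative reconstruction under assumptions, not the page's actual definition:

```julia
using Turing

# Hedged sketch: the name `threaded_obs_addlogprob` appears later on this page,
# but this body is an illustrative reconstruction, not the page's own code.
@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()
    # Compute per-observation log-likelihoods in parallel. Each iteration
    # writes to its own slot of `lls`, so there is no data race.
    lls = zeros(typeof(x), N)  # eltype follows x so this also works under autodiff
    Threads.@threads for i in 1:N
        lls[i] = logpdf(Normal(x), y[i])
    end
    # Add the total log-likelihood *outside* the threaded block.
    @addlogprob! sum(lls)
end
```

With this structure, only plain Julia code runs inside `Threads.@threads`; the single `@addlogprob!` call happens serially, which is what makes the pattern safe.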
@@ -187,7 +187,12 @@ See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-t
 
 We make no promises about the use of tilde-statements _with_ these packages (indeed it will most likely error), but as long as you use them to only parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.
 
-One benefit of rewriting the model this way is that sampling from this model with `MCMCThreads()` will always be reproducible.
+The main downsides of this approach are:
+
+1. You can't use conditioning syntax to provide data; it has to be passed as an argument or otherwise included inside the model.
+2. You can't use `predict` to sample new data.
+
+On the other hand, one benefit of rewriting the model this way is that sampling from this model with `MCMCThreads()` will always be reproducible.
 
 ```{julia}
 using Random
@@ -196,17 +201,19 @@ y = randn(N)
 model = threaded_obs_addlogprob(N, y)
 nuts_kwargs = (check_model=false, progress=false, verbose=false)
 
-chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
-chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
+chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
+chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
 mean(chain1[:x]), mean(chain2[:x]) # should be identical
 ```
 
 In contrast, the original `threaded_obs` (which used tilde inside `Threads.@threads`) is not reproducible when using `MCMCThreads()`.
+(In principle, we would like to fix this bug, but we haven't yet investigated where it stems from.)
 
 ```{julia}
 model = threaded_obs(N) | (; y = y)
-chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
-chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
+nuts_kwargs = (check_model=false, progress=false, verbose=false)
+chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
+chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
 mean(chain1[:x]), mean(chain2[:x]) # oops!
 ```
 
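To see why drawing values inside a threaded loop loses reproducibility, here is a minimal non-Turing sketch (purely illustrative; `threaded_draws` is an invented name and none of this is Turing's API):

```julia
using Random

# When several threads draw from a shared RNG, the thread interleaving decides
# which draw lands at which index, so identical seeds need not give identical
# results. Illustrative only; not Turing internals.
function threaded_draws(rng, N)
    xs = zeros(N)
    lk = ReentrantLock()
    Threads.@threads for i in 1:N
        lock(lk) do
            xs[i] = randn(rng)  # which draw reaches index i depends on scheduling
        end
    end
    return xs
end

# With more than one thread, these two runs can differ despite equal seeds:
threaded_draws(Xoshiro(468), 1000) == threaded_draws(Xoshiro(468), 1000)
```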

@@ -228,10 +235,36 @@ In particular:
 This part will likely only be of interest to DynamicPPL developers and the very curious user.
 :::
 
-TODO: Something about metadata, accumulators, and TSVI.
+### Why is VarInfo not threadsafe?
+
+As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.
+
+Traditionally, VarInfo objects contain both *metadata* and *accumulators*.
+Metadata is where information about the random variables' values is stored.
+It is a Dict-like structure, and pushing to it from multiple threads is therefore not threadsafe (Julia's `Dict` has similar limitations).
+
+On the other hand, accumulators are used to store outputs of the model, such as log-probabilities.
+The way DynamicPPL's threadsafe evaluation works is to create one set of accumulators per thread, and then combine the results at the end of model evaluation.
+
+In this way, any function call that involves _solely_ accumulators can be made threadsafe.
+For example, this is why observations are supported: there is no need to modify metadata, and only the log-likelihood accumulator needs to be updated.
+
+However, `assume` tilde-statements always modify the metadata, and thus cannot currently be made threadsafe.
+
+### OnlyAccsVarInfo
+
+As it happens, much of what is needed in DynamicPPL can be constructed such that it relies *only* on accumulators.
+
+For example, as long as there is no need to *sample* new values of random variables, it is actually fine to completely omit the metadata object.
+This is the case for `LogDensityFunction`: since values are provided in the input vector, there is no need to store them in metadata.
+We need only calculate the associated log-prior probability, which is stored in an accumulator.
+Thus, starting from DynamicPPL v0.39, `LogDensityFunction` itself is in fact completely threadsafe.
 
-TODO: Say how OnlyAccsVarInfo and FastLDF changes this.
+Technically speaking, this is achieved using `OnlyAccsVarInfo`, a subtype of `VarInfo` that contains only accumulators and no metadata at all.
+It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any function attempts to modify or read its metadata.
 
-Essentially, `predict(model, chn)` SHOULD work after #1130 because that uses OAVI, which doesn't have Metadata. It uses VAIMAcc to accumulate the values, but that is threadsafe as long as TSVI is used.
+There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can.
+For example, this is why `predict` will be threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator` and combine them at the end of evaluation.
 
-FastLDF, _once constructed_, also works with threaded assume. The only problem is that to get the ranges and linked status it has to first generate a VarInfo, which cannot be done. But if there's a way to either manually provide the ranges OR use an accumulator instead to get the ranges/linked status, then it would straight up enable threaded assume with NUTS / any sampler that only uses FastLDF.
+However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata).
+See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.
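The "one accumulator per task, combine at the end" strategy described in this deep dive can be illustrated with a generic sketch. This shows the general pattern only, not DynamicPPL's actual implementation; `chunked_loglikelihood` and everything inside it are invented for illustration:

```julia
using Distributions

# One accumulator per task; combine at the end. Mirrors the idea of per-thread
# log-probability accumulators, in deliberately simplified form.
function chunked_loglikelihood(dist, ys; nchunks = Threads.nthreads())
    chunks = Iterators.partition(eachindex(ys), cld(length(ys), nchunks))
    tasks = map(collect(chunks)) do idxs
        Threads.@spawn begin
            acc = 0.0                      # this task's private accumulator
            for i in idxs
                acc += logpdf(dist, ys[i])
            end
            acc
        end
    end
    return sum(fetch, tasks)               # combine per-task results in a fixed order
end
```

Because each task touches only its own accumulator, no locking is needed, and the combine step is an ordinary reduction over the tasks in a fixed order, which also keeps the result deterministic.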
