Commit f433446

Fix VI interface

1 parent e4d56e8 commit f433446
tutorials/variational-inference/index.qmd

Lines changed: 45 additions & 12 deletions
@@ -148,6 +148,7 @@ m = linear_regression(train, train_label, n_obs, n_vars);
 To run VI, we must first set a *variational family*.
 For instance, the most commonly used family is the mean-field Gaussian family.
 For this, Turing provides functions that automatically construct the initialisation corresponding to the model `m`:
+
 ```{julia}
 q_init = q_meanfield_gaussian(m);
 ```
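
As an aside, `q_init` should behave like an ordinary multivariate distribution. Assuming the standard `Distributions.jl` `rand` interface applies to it, a minimal sketch for sanity-checking the initialisation:

```{julia}
# Sketch: draw 5 samples from the initial mean-field approximation.
# Assumes `q_init` from the cell above and that it supports the
# Distributions.jl `rand` interface.
z_init = rand(q_init, 5)
```
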
@@ -161,10 +162,12 @@ Here is a detailed documentation for the constructor:
 As we can see, the precise initialisation can be customized through the keyword arguments.
 
 Let's run VI with the default setting:
+
 ```{julia}
 n_iters = 1000
-q_avg, q_last, info, state = vi(m, q_init, n_iters; show_progress=false);
+q_avg, info, state = vi(m, q_init, n_iters; show_progress=false);
 ```
+
 The default setting uses the `AdvancedVI.RepGradELBO` objective, which corresponds to a variant of what is known as *automatic differentiation VI*[^KTRGB2017] or *stochastic gradient VI*[^TL2014] or *black-box VI*[^RGB2014] with the reparameterization gradient[^KW2014][^RMW2014][^TL2014].
 The default optimiser we use is `AdvancedVI.DoWG`[^KMJ2023] combined with a proximal operator.
 (The use of proximal operators with VI on a location-scale family is discussed in detail by J. Domke[^D2020][^DGG2023] and others[^KOWMG2023].)
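
As a hedged aside, the objective can also be estimated outside the optimisation loop. Reusing the `estimate_objective`, `AdvancedVI.RepGradELBO`, and `LogDensityFunction` calls that appear later in this tutorial (the sample count `32` and the name `elbo_init` are illustrative choices), one could check the ELBO of the initial approximation before running `vi`:

```{julia}
# Sketch: estimate the ELBO of the untrained approximation `q_init`.
# `estimate_objective` and `LogDensityFunction` are used the same way later in
# this tutorial; 32 Monte Carlo samples is an arbitrary illustrative choice.
elbo_init = estimate_objective(AdvancedVI.RepGradELBO(32), q_init, LogDensityFunction(m))
@info "ELBO at initialisation" elbo_init
```
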
@@ -178,8 +181,24 @@ First, here is the full documentation for `vi`:
 ## Values Returned by `vi`
 The main output of the algorithm is `q_avg`, the average of the parameters generated by the optimisation algorithm.
 For computing `q_avg`, the default setting uses what is known as polynomial averaging[^SZ2013].
-Usually, `q_avg` will perform better than the last-iterate `q_last`.
+Usually, `q_avg` will perform better than the last-iterate `q_last`, which can be obtained by disabling averaging:
+
+```{julia}
+q_last, _, _ = vi(
+    m,
+    q_init,
+    n_iters;
+    show_progress=false,
+    algorithm=KLMinRepGradDescent(
+        AutoForwardDiff();
+        operator=AdvancedVI.ClipScale(),
+        averager=AdvancedVI.NoAveraging()
+    ),
+);
+```
+
 For instance, we can compare the ELBO of the two:
+
 ```{julia}
 @info("Objective of q_avg and q_last",
     ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, LogDensityFunction(m)),
@@ -194,6 +213,7 @@ For the default setting, which is `RepGradELBO`, it contains the ELBO estimated
 ```{julia}
 Plots.plot([i.elbo for i in info], xlabel="Iterations", ylabel="ELBO", label="info")
 ```
+
 Since the ELBO is estimated by a small number of samples, it appears noisy.
 Furthermore, at each step, the ELBO is evaluated on `q_last`, not `q_avg`, which is the actual output that we care about.
 To obtain more accurate ELBO estimates evaluated on `q_avg`, we have to define a custom callback function.
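
As a brief, hedged illustration of the point above, one could compare ELBO estimates computed with different numbers of Monte Carlo samples; the larger count should give a noticeably less noisy value. The calls below reuse `estimate_objective` with the sample counts 32 and 128 that appear in this tutorial; only the side-by-side comparison is new:

```{julia}
# Sketch: the Monte Carlo sample count controls the noise of the ELBO estimate.
# Both calls mirror usages elsewhere in this tutorial.
elbo_32  = estimate_objective(AdvancedVI.RepGradELBO(32),  q_avg, LogDensityFunction(m))
elbo_128 = estimate_objective(AdvancedVI.RepGradELBO(128), q_avg, LogDensityFunction(m))
@info "ELBO estimates with different sample counts" elbo_32 elbo_128
```
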
@@ -203,30 +223,38 @@ To inspect the progress of optimisation in more detail, one can define a custom
 For example, the following callback function estimates the ELBO on `q_avg` every 10 steps with a larger number of samples:
 
 ```{julia}
-function callback(; stat, averaged_params, restructure, kwargs...)
-    if mod(stat.iteration, 10) == 1
+using DynamicPPL: DynamicPPL
+
+function callback(; iteration, averaged_params, restructure, kwargs...)
+    if mod(iteration, 10) == 1
         q_avg = restructure(averaged_params)
-        obj = AdvancedVI.RepGradELBO(128)
-        elbo_avg = estimate_objective(obj, q_avg, LogDensityFunction(m))
+        obj = AdvancedVI.RepGradELBO(128) # 128 samples for ELBO estimation
+        vi = DynamicPPL.link!!(DynamicPPL.VarInfo(m), m);
+        elbo_avg = -estimate_objective(obj, q_avg, LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi))
         (elbo_avg = elbo_avg,)
     else
         nothing
     end
 end;
 ```
+
 The `NamedTuple` returned by `callback` will be appended to the corresponding entry of `info`, and it will also be displayed on the progress meter if `show_progress` is set as `true`.
 
 The custom callback can be supplied to `vi` as a keyword argument:
+
 ```{julia}
-q_mf, _, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback);
+q_mf, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback);
 ```
 
 Let's plot the result:
+
 ```{julia}
 iters = 1:10:length(info_mf)
 elbo_mf = [i.elbo_avg for i in info_mf[iters]]
-Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf))
+Plots.plot([i.elbo for i in info], xlabel="Iterations", ylabel="ELBO", label="info", linewidth=0.4)
+Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf), linewidth=2)
 ```
+
 We can see that the ELBO values are less noisy and progress more smoothly due to averaging.
 
 ## Using Different Optimisers
@@ -244,7 +272,7 @@ Since `AdvancedVI` does not implement a proximal operator for `Optimisers.Adam`,
 ```{julia}
 using Optimisers
 
-_, _, info_adam, _ = vi(
+_, info_adam, _ = vi(
     m, q_init, n_iters;
     show_progress=false,
     callback=callback,
@@ -265,6 +293,7 @@ That is, most common optimisers require some degree of tuning to perform better
 Due to this fact, they are referred to as parameter-free optimizers.
 
 ## Using Full-Rank Variational Families
+
 So far, we have only used the mean-field Gaussian family.
 This, however, approximates the posterior covariance with a diagonal matrix.
 To model the full covariance matrix, we can use the *full-rank* Gaussian family[^TL2014][^KTRGB2017]:
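
The full-rank initialisation referenced below as `q_init_fr` is presumably constructed much like the mean-field one. A minimal sketch, assuming Turing provides a `q_fullrank_gaussian` counterpart to the `q_meanfield_gaussian` constructor shown earlier:

```{julia}
# Sketch: construct a full-rank Gaussian initialisation for the model `m`.
# `q_fullrank_gaussian` is assumed to mirror `q_meanfield_gaussian` above.
q_init_fr = q_fullrank_gaussian(m);
```
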
@@ -283,7 +312,7 @@ This term, however, traditionally comes from the fact that full-rank families us
 In contrast to the mean-field family, the full-rank family will often result in more computation per optimisation step and slower convergence, especially in high dimensions:
 
 ```{julia}
-q_fr, _, info_fr, _ = vi(m, q_init_fr, n_iters; show_progress=false, callback)
+q_fr, info_fr, _ = vi(m, q_init_fr, n_iters; show_progress=false, callback)
 
 Plots.plot(elbo_mf, xlabel="Iterations", ylabel="ELBO", label="Mean-Field", ylims=(-200, Inf))
@@ -292,7 +321,7 @@ Plots.plot!(elbo_fr, xlabel="Iterations", ylabel="ELBO", label="Full-Rank", ylim
 ```
 
 However, we can see that the full-rank families achieve a higher ELBO in the end.
-Due to the relationship between the ELBO and the Kullback-Leibler divergence, this indicates that the full-rank covariance is much more accurate.
+Due to the relationship between the ELBO and the Kullback–Leibler divergence, this indicates that the full-rank covariance is much more accurate.
 This trade-off between statistical accuracy and optimisation speed is often referred to as the *statistical-computational trade-off*.
 The fact that we can control this trade-off through the choice of variational family is a strength, rather than a limitation, of variational inference.
 
@@ -342,26 +371,29 @@ avg[union(sym2range[:coefficients]...)]
 ```
 
 For further convenience, we can wrap the samples into a `Chains` object to summarise the results.
+
 ```{julia}
 varinf = Turing.DynamicPPL.VarInfo(m)
 vns_and_values = Turing.DynamicPPL.varname_and_value_leaves(Turing.DynamicPPL.values_as(varinf, OrderedDict))
 varnames = map(first, vns_and_values)
 vi_chain = Chains(reshape(z', (size(z,2), size(z,1), 1)), varnames)
 ```
+
 (Since we're drawing independent samples, we can simply ignore the ESS and Rhat metrics.)
 Unfortunately, extracting `varnames` is a bit verbose at the moment, but hopefully it will become simpler in the near future.
 
 Let's compare this against samples from `NUTS`:
 
 ```{julia}
-mcmc_chain = sample(m, NUTS(), 10_000, drop_warmup=true, progress=false);
+mcmc_chain = sample(m, NUTS(), 10_000; progress=false);
 
 vi_mean = mean(vi_chain)[:, 2]
 mcmc_mean = mean(mcmc_chain, names(mcmc_chain, :parameters))[:, 2]
 
 plot(mcmc_mean; xticks=1:1:length(mcmc_mean), label="mean of NUTS")
 plot!(vi_mean; label="mean of VI")
 ```
+
 That looks pretty good! But let's see how the predictive distributions look for the two.
 
 ## Making Predictions
@@ -516,6 +548,7 @@ title!("MCMC (NUTS)")
 
 plot(p1, p2, p3; layout=(1, 3), size=(900, 250), label="")
 ```
+
 We can see that the full-rank VI approximation is very close to the predictions from MCMC samples.
 Also, the coverage of full-rank VI and MCMC is much better than the crude mean-field approximation.
 