|
| 1 | +"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`. |
| 2 | +
|
| 3 | +On my CPU-only machine: |
| 4 | +``` |
| 5 | +bash> python scan_stages_cnf.py --scan_stages=False --backsolve=False |
| 6 | +Compile+run time 79.18114789901301 |
| 7 | +Run time 0.16631506383419037 |
| 8 | +
|
| 9 | +bash> python scan_stages_cnf.py --scan_stages=False --backsolve=True |
| 10 | +Compile+run time 28.233896102989092 |
| 11 | +Run time 0.021237157052382827 |
| 12 | +
|
| 13 | +bash> python scan_stages_cnf.py --scan_stages=True --backsolve=False |
| 14 | +Compile+run time 37.9795492868870 |
| 15 | +Run time 0.16300765215419233 |
| 16 | +
|
| 17 | +bash> python scan_stages_cnf.py --scan_stages=True --backsolve=True |
| 18 | +Compile+run time 12.199542510090396 |
| 19 | +Run time 0.024600893026217818 |
| 20 | +``` |
| 21 | +
|
| 22 | +(Not forgetting that --backsolve=True produces only approximate gradients, so the fact |
| 23 | +that it obtains better compile time and run time doesn't mean it's always the best |
| 24 | +choice.) |
| 25 | +""" |
| 26 | + |
| 27 | +# This benchmark is adapted from |
| 28 | +# https://github.com/patrick-kidger/diffrax/issues/94#issuecomment-1140527134 |
| 29 | + |
| 30 | +import functools as ft |
| 31 | +import timeit |
| 32 | + |
| 33 | +import diffrax |
| 34 | +import equinox as eqx |
| 35 | +import fire |
| 36 | +import jax |
| 37 | +import jax.nn as jnn |
| 38 | +import jax.numpy as jnp |
| 39 | +import jax.random as jr |
| 40 | +import jax.scipy as jsp |
| 41 | + |
| 42 | + |
def vector_field_prob(t, input, model):
    """Joint CNF vector field: evolves the state and its log-density change.

    `input` is a pair `(y, log_p)`; only `y` is needed to evaluate the
    dynamics. Returns `(dy/dt, d(log_p)/dt)`, where the log-density rate is
    the divergence (Jacobian trace) of `model` at `y`, assembled one row at a
    time from vector-Jacobian products.
    """
    del t  # autonomous dynamics: time is unused
    y, _ = input
    dy, pullback = jax.vjp(model, y)
    (dim,) = y.shape
    # One VJP per standard basis vector reconstructs the full Jacobian.
    (jacobian,) = jax.vmap(pullback)(jnp.eye(dim))
    divergence = jnp.trace(jacobian)
    return dy, divergence
| 51 | + |
| 52 | + |
@eqx.filter_vmap(args=(None, 0, None, None))
def log_prob(model, y0, scan_stages, backsolve):
    """Return the CNF log-likelihood of a single sample `y0`.

    Integrates the state together with its accumulated log-density change
    from t=0 to t=0.5, then adds the standard-normal log-density of the
    terminal state. Vmapped over `y0` (axis 0) only; `scan_stages` selects
    the Dopri5 stage implementation being benchmarked and `backsolve`
    chooses the (approximate-gradient) backsolve adjoint over checkpointing.
    """
    adjoint = (
        diffrax.BacksolveAdjoint()
        if backsolve
        else diffrax.RecursiveCheckpointAdjoint()
    )
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field_prob),
        diffrax.Dopri5(scan_stages=scan_stages),
        t0=0.0,
        t1=0.5,
        dt0=0.05,
        y0=(y0, 0.0),
        args=model,
        stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8),
        adjoint=adjoint,
    )
    # `sol.ys` holds a single saved time point; unpack state and log-density.
    (y1,), (delta_logp,) = sol.ys
    return delta_logp + jsp.stats.norm.logpdf(y1).sum(0)
| 75 | + |
| 76 | + |
@eqx.filter_jit
@eqx.filter_grad
def solve(model, inputs, scan_stages, backsolve):
    """Gradient (w.r.t. `model`) of the mean negative log-likelihood."""
    nll = -log_prob(model, inputs, scan_stages, backsolve)
    return nll.mean()
| 81 | + |
| 82 | + |
def main(scan_stages, backsolve):
    """Build a small CNF and time one compiled and one cached evaluation.

    The first `timeit` call includes tracing + XLA compilation; the second
    reuses the compiled computation, so the two prints separate compile
    cost from steady-state run cost.
    """
    model_key, data_key = jr.split(jr.PRNGKey(0), 2)
    model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=model_key)
    x = jr.normal(data_key, (256, 2))
    run = ft.partial(solve, model, x, scan_stages, backsolve)
    print("Compile+run time", timeit.timeit(run, number=1))
    print("Run time", timeit.timeit(run, number=1))
| 90 | + |
| 91 | + |
| 92 | +fire.Fire(main) |
0 commit comments