Skip to content

Commit a98546d

Browse files
Added AbstractRungeKutta(scan_stages=...) to improve compile times
This involved rewriting pretty much the entirety of the solver!
1 parent af7fe6d commit a98546d

File tree

9 files changed

+648
-240
lines changed

9 files changed

+648
-240
lines changed

benchmarks/scan_stages.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`.
2+
3+
On my CPU-only machine:
4+
```
5+
bash> python scan_stages.py False
6+
Compile+run time 24.38062646985054
7+
Run time 0.0018830380868166685
8+
9+
bash> python scan_stages.py True
10+
Compile+run time 11.418417416978627
11+
Run time 0.0014536201488226652
12+
```
13+
"""
14+
15+
import functools as ft
16+
import timeit
17+
18+
import diffrax as dfx
19+
import equinox as eqx
20+
import fire
21+
import jax.numpy as jnp
22+
import jax.random as jr
23+
24+
25+
def _weight(in_, out, key):
    """Sample a standard-normal weight matrix as nested lists of scalars.

    Args:
        in_: Number of columns (length of each inner list).
        out: Number of rows (length of the outer list).
        key: JAX PRNG key used for sampling.

    Returns:
        A list of `out` lists, each holding `in_` scalar arrays.
    """
    matrix = jr.normal(key, (out, in_))
    # Split the dense matrix into individual scalar entries, row by row.
    return [list(row) for row in matrix]
27+
28+
29+
class VectorField(eqx.Module):
    """Toy linear network used as an ODE vector field.

    The weights are stored as nested lists of scalars (see `_weight`) and
    applied with explicit Python-level sums, which produces a deliberately
    large computation graph.
    """

    # One entry per layer; each entry is a list of rows of scalar weights.
    weights: list

    def __init__(self, in_, out, width, depth, *, key):
        """Create an input layer, `depth - 1` hidden layers and an output layer.

        Args:
            in_: Size of the input vector.
            out: Size of the output vector.
            width: Size of every hidden vector.
            depth: Controls the number of layers; `depth + 1` keys are drawn.
            key: JAX PRNG key, split into one key per layer.
        """
        layer_keys = jr.split(key, depth + 1)
        hidden_layers = [
            _weight(width, width, layer_keys[i]) for i in range(1, depth)
        ]
        self.weights = [
            _weight(in_, width, layer_keys[0]),
            *hidden_layers,
            _weight(width, out, layer_keys[depth]),
        ]

    def __call__(self, t, y, args):
        """Apply every layer to `y` in turn; `t` and `args` are unused."""
        # Inefficient computation graph to make a toy example more expensive.
        state = list(y)
        for layer in self.weights:
            state = [
                sum(w_ij * s_j for w_ij, s_j in zip(row, state)) for row in layer
            ]
        return jnp.stack(state)
45+
46+
47+
def main(scan_stages):
    """Time a `Dopri8` solve of the toy vector field.

    Prints the wall-clock time of the first call (compilation plus run) and of
    a second call (run only, reusing the compiled function).

    Args:
        scan_stages: Forwarded to `dfx.Dopri8(scan_stages=...)`.
    """
    vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0))
    term = dfx.ODETerm(vf)
    solver = dfx.Dopri8(scan_stages=scan_stages)
    stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6)
    t0 = 0
    t1 = 1
    dt0 = None

    @eqx.filter_jit
    def solve(y0):
        return dfx.diffeqsolve(
            term, solver, t0, t1, dt0, y0, stepsize_controller=stepsize_controller
        )

    y0 = jnp.array([1.0])

    def timed_solve():
        # JAX may return before the computation has finished (asynchronous
        # dispatch). Block on an output array so that the timer measures the
        # actual solve rather than just the dispatch.
        return solve(y0).ys.block_until_ready()

    print("Compile+run time", timeit.timeit(timed_solve, number=1))
    print("Run time", timeit.timeit(timed_solve, number=1))
65+
66+
67+
# Expose `main` as the command-line entry point, e.g.
# `python scan_stages.py False`.
fire.Fire(main)

benchmarks/scan_stages_cnf.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`.
2+
3+
On my CPU-only machine:
4+
```
5+
bash> python scan_stages_cnf.py --scan_stages=False --backsolve=False
6+
Compile+run time 79.18114789901301
7+
Run time 0.16631506383419037
8+
9+
bash> python scan_stages_cnf.py --scan_stages=False --backsolve=True
10+
Compile+run time 28.233896102989092
11+
Run time 0.021237157052382827
12+
13+
bash> python scan_stages_cnf.py --scan_stages=True --backsolve=False
14+
Compile+run time 37.9795492868870
15+
Run time 0.16300765215419233
16+
17+
bash> python scan_stages_cnf.py --scan_stages=True --backsolve=True
18+
Compile+run time 12.199542510090396
19+
Run time 0.024600893026217818
20+
```
21+
22+
(Not forgetting that --backsolve=True produces only approximate gradients, so the fact
23+
that it obtains better compile time and run time doesn't mean it's always the best
24+
choice.)
25+
"""
26+
27+
# This benchmark is adapted from
28+
# https://github.com/patrick-kidger/diffrax/issues/94#issuecomment-1140527134
29+
30+
import functools as ft
31+
import timeit
32+
33+
import diffrax
34+
import equinox as eqx
35+
import fire
36+
import jax
37+
import jax.nn as jnn
38+
import jax.numpy as jnp
39+
import jax.random as jr
40+
import jax.scipy as jsp
41+
42+
43+
def vector_field_prob(t, input, model):
    """Evaluate `model` at the state together with the trace of its Jacobian.

    Args:
        t: Unused.
        input: Pair whose first element is the 1-D state; the second element
            is ignored.
        model: Callable mapping the state to an array of the same shape.

    Returns:
        A pair `(model(state), trace(d model / d state))`.
    """
    state, _ = input
    velocity, pullback = jax.vjp(model, state)
    # Unpacking also enforces that the state is one-dimensional.
    (dim,) = state.shape
    basis = jnp.eye(dim)
    # One vector-Jacobian product per basis vector assembles the full Jacobian.
    (jacobian,) = jax.vmap(pullback)(basis)
    return velocity, jnp.trace(jacobian)
51+
52+
53+
@eqx.filter_vmap(args=(None, 0, None, None))
def log_prob(model, y0, scan_stages, backsolve):
    """Solve the augmented ODE from `y0` and return a log-probability.

    The decorator maps this function over axis 0 of `y0` only; the other
    arguments are passed through unbatched.

    Args:
        model: Passed to the vector field as `args`.
        y0: Initial state for a single sample.
        scan_stages: Forwarded to `diffrax.Dopri5(scan_stages=...)`.
        backsolve: If truthy use `diffrax.BacksolveAdjoint`, otherwise
            `diffrax.RecursiveCheckpointAdjoint`.

    Returns:
        The integrated trace term plus the standard-normal log-density of the
        final state, summed over axis 0.
    """
    adjoint_cls = (
        diffrax.BacksolveAdjoint if backsolve else diffrax.RecursiveCheckpointAdjoint
    )
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field_prob),
        diffrax.Dopri5(scan_stages=scan_stages),
        t0=0.0,
        t1=0.5,
        dt0=0.05,
        # The second component accumulates the trace term, starting from zero.
        y0=(y0, 0.0),
        args=model,
        stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8),
        adjoint=adjoint_cls(),
    )
    # Each component of `ys` holds a single saved value; unpack it.
    (y_final,), (delta_logp,) = solution.ys
    base_logp = jsp.stats.norm.logpdf(y_final).sum(0)
    return delta_logp + base_logp
75+
76+
77+
@eqx.filter_jit
@eqx.filter_grad
def solve(model, inputs, scan_stages, backsolve):
    """Mean negative log-probability of `inputs`.

    Wrapped in `eqx.filter_grad` and `eqx.filter_jit`, so calling `solve`
    yields the jit-compiled gradient of this loss rather than the loss itself.
    """
    per_sample = log_prob(model, inputs, scan_stages, backsolve)
    return -per_sample.mean()
81+
82+
83+
def main(scan_stages, backsolve):
    """Time the gradient computation of the CNF benchmark.

    Prints the wall-clock time of the first call (compilation plus run) and of
    a second call (run only, reusing the compiled function).

    Args:
        scan_stages: Forwarded to `solve`, and from there to the solver.
        backsolve: Forwarded to `solve`; selects the adjoint method.
    """
    mkey, dkey = jr.split(jr.PRNGKey(0), 2)
    model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey)
    x = jr.normal(dkey, (256, 2))

    def timed_solve():
        grads = solve(model, x, scan_stages, backsolve)
        # JAX may return before the computation has finished (asynchronous
        # dispatch). Wait for every array in the result so that the timer
        # measures the actual computation rather than just the dispatch.
        for leaf in jax.tree_util.tree_leaves(grads):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()
        return grads

    print("Compile+run time", timeit.timeit(timed_solve, number=1))
    print("Run time", timeit.timeit(timed_solve, number=1))
90+
91+
92+
# Expose `main` as the command-line entry point, e.g.
# `python scan_stages_cnf.py --scan_stages=True --backsolve=False`.
fire.Fire(main)

0 commit comments

Comments
 (0)