
Commit aa27945

Merge pull request #127 from patrick-kidger/backsolve-docs
Updated BacksolveAdjoint docs
2 parents eefce55 + 67c3bd0 commit aa27945

File tree

3 files changed: +99 −23 lines

diffrax/adjoint.py

Lines changed: 9 additions & 2 deletions
@@ -276,10 +276,17 @@ class BacksolveAdjoint(AbstractAdjoint):
     "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
     method".
 
-    This method implies very low memory usage, but is usually relatively slow, and the
+    This method implies very low memory usage, but the
     computed gradients will only be approximate. As such other methods are generally
     preferred unless exceeding memory is a concern.
 
+    This will compute gradients with respect to the `terms`, `y0` and `args` arguments
+    passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
+    respect to anything else (for example `t0`, or arguments passed via closure), then
+    a `CustomVJPException` will be raised. See also
+    [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
+    entry.
+
     !!! note
 
         This was popularised by [this paper](https://arxiv.org/abs/1806.07366). For
@@ -290,7 +297,7 @@ class BacksolveAdjoint(AbstractAdjoint):
 
         Using this method prevents computing forward-mode autoderivatives of
         [`diffrax.diffeqsolve`][]. (That is to say, `jax.jvp` will not work.)
-    """
+    """  # noqa: E501
 
     kwargs: Dict[str, Any]
 
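As a quick illustration of the behaviour the updated docstring describes: the following minimal sketch (not part of this commit; the Euler solver, step size and vector field are arbitrary choices) differentiates through `y0`, which is one of the supported inputs.

```python
import diffrax
import jax
import jax.numpy as jnp


def f(t, y, args):
    # Simple parameter-free linear vector field.
    return -y


def solve(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f),
        diffrax.Euler(),
        0,
        1,
        0.1,
        y0,
        adjoint=diffrax.BacksolveAdjoint(),
    )
    return jnp.sum(sol.ys)


# Differentiating with respect to `y0` is supported; attempting the same with
# respect to e.g. `t0` or a closed-over parameter would raise a CustomVJPException.
print(jax.grad(solve)(jnp.array([2.0])))
```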
docs/further_details/faq.md

Lines changed: 38 additions & 21 deletions
@@ -4,7 +4,7 @@
 
 Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
 
-### I'm getting zero gradient for one of my model parameters.
+### I'm getting a `CustomVJPException`.
 
 This can happen if you use [`diffrax.BacksolveAdjoint`][] incorrectly.
 
@@ -14,39 +14,56 @@ Gradients will be computed for:
 - Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`.
 - Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.
 
+Attempting to compute gradients with respect to anything else will result in this exception.
 
 !!! example
 
-    Gradients through `args` and `y0` are self-explanatory. Meanwhile, a common example of computing gradients through `terms` is if using an [Equinox](https://github.com/patrick-kidger/equinox) module to represent a parameterised vector field. For example:
+    Here is a minimal example of **wrong** code that will raise this exception.
 
     ```python
+    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
     import equinox as eqx
-    import diffrax
+    import jax.numpy as jnp
+    import jax.random as jr
 
-    class Func(eqx.Module):
-        mlp: eqx.nn.MLP
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))
 
-        def __call__(self, t, y, args):
-            return self.mlp(y)
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        def f(t, y, args):  # `model` captured via closure; is not part of the `terms` PyTree.
+            return model(y)
+        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
+                          adjoint=BacksolveAdjoint())
+        return jnp.sum(sol.ys)
 
-    mlp = eqx.nn.MLP(...)
-    func = Func(mlp)
-    term = diffrax.ODETerm(func)
-    diffrax.diffeqsolve(term, ..., adjoint=diffrax.BacksolveAdjoint())
+    run(mlp)
     ```
 
-    In this case `diffrax.ODETerm`, `Func` and `eqx.nn.MLP` are all PyTrees, so all of the parameters inside `mlp` are visible to `diffeqsolve` and it can compute gradients with respect to them.
+!!! example
+
+    The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`.
+
+    ```python
+    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
+    import equinox as eqx
+    import jax.numpy as jnp
+    import jax.random as jr
 
-    However if you were to do:
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))
 
-    ```python
-    model = ...
+    class VectorField(eqx.Module):
+        model: eqx.Module
 
-    def func(t, y, args):
-        return model(y)
+        def __call__(self, t, y, args):
+            return self.model(y)
 
-    term = diffrax.ODETerm(func)
-    diffrax.diffeqsolve(term, ..., adjoint=diffrax.BacksolveAdjoint())
-    ```
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        f = VectorField(model)
+        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]), adjoint=BacksolveAdjoint())
+        return jnp.sum(sol.ys)
 
-    then the parameters of `model` are not visible to `diffeqsolve` so gradients will not be computed with respect to them.
+    run(mlp)
+    ```
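The updated FAQ also lists `args` as a differentiable input, so routing parameters through `args` (rather than a closure) is another way to avoid the exception. A minimal sketch, not part of this commit, assuming `diffeqsolve`'s `args` keyword and the same Euler/`ODETerm` setup as the examples above:

```python
from diffrax import BacksolveAdjoint, Euler, ODETerm, diffeqsolve
import jax
import jax.numpy as jnp


def f(t, y, args):
    # `args` carries the decay-rate parameter, so it is visible to diffeqsolve.
    return -args * y


def run(theta):
    sol = diffeqsolve(
        ODETerm(f),
        Euler(),
        0,
        1,
        0.1,
        jnp.array([1.0]),
        args=theta,
        adjoint=BacksolveAdjoint(),
    )
    return jnp.sum(sol.ys)


# The gradient with respect to `theta` flows through `args`, so no CustomVJPException.
print(jax.grad(run)(jnp.array(2.0)))
```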

test/test_adjoint.py

Lines changed: 52 additions & 0 deletions
@@ -4,6 +4,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.random as jrandom
 import pytest
 
 from .helpers import shaped_allclose
@@ -141,3 +142,54 @@ def solve(y0):
         return jnp.sum(sol.ys)
 
     jax.grad(solve)(2.0)
+
+
+def test_closure_errors():
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))
+
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        def f(t, y, args):
+            return model(y)
+
+        sol = diffrax.diffeqsolve(
+            diffrax.ODETerm(f),
+            diffrax.Euler(),
+            0,
+            1,
+            0.1,
+            jnp.array([1.0]),
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    with pytest.raises(jax.interpreters.ad.CustomVJPException):
+        run(mlp)
+
+
+def test_closure_fixed():
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))
+
+    class VectorField(eqx.Module):
+        model: eqx.Module
+
+        def __call__(self, t, y, args):
+            return self.model(y)
+
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        f = VectorField(model)
+        sol = diffrax.diffeqsolve(
+            diffrax.ODETerm(f),
+            diffrax.Euler(),
+            0,
+            1,
+            0.1,
+            jnp.array([1.0]),
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    run(mlp)
