docs/further_details/faq.md
Try switching to 64-bit precision, instead of the 32-bit that is the default in JAX. [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
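For example, 64-bit precision can be enabled globally with a single config update (this is JAX's documented mechanism; it should be done before any arrays are created):

```python
import jax

# Enable 64-bit precision; do this at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)  # float64 rather than the default float32
```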

### I'm getting a `CustomVJPException`.

This can happen if you use [`diffrax.BacksolveAdjoint`][] incorrectly.

Gradients will be computed for:

- Everything in the `args` PyTree passed to `diffeqsolve(..., args=args)`.
- Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`.
- Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.
Attempting to compute gradients with respect to anything else will result in this exception.
!!! example

    Here is a minimal example of **wrong** code that will raise this exception.

    ```python
    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        def f(t, y, args):  # `model` captured via closure; it is not part of the `terms` PyTree.
            return model(y)
        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                          adjoint=BacksolveAdjoint())
        return jnp.sum(sol.ys)

    run(mlp)  # raises CustomVJPException
    ```

!!! example

    The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`.

    ```python
    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    class Func(eqx.Module):
        mlp: eqx.nn.MLP

        def __call__(self, t, y, args):
            return self.mlp(y)

    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        sol = diffeqsolve(ODETerm(model), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                          adjoint=BacksolveAdjoint())
        return jnp.sum(sol.ys)

    run(Func(mlp))
    ```

    In this case `diffrax.ODETerm`, `Func` and `eqx.nn.MLP` are all PyTrees, so all of the parameters inside `mlp` are visible to `diffeqsolve` and it can compute gradients with respect to them.