Skip to content

Commit d1c8e79

Browse files
Merge pull request #167 from patrick-kidger/jax-experimental-ode-faq
Add comparison to `jax.experimental.ode`
2 parents 6d8a6ac + 430e31c commit d1c8e79

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

docs/further_details/faq.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@ If you're using a Runge--Kutta method like [`diffrax.Dopri5`][] etc., then try s
88

99
Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
1010

11+
### How does this compare to `jax.experimental.ode.odeint`?
12+
13+
The equivalent solver in Diffrax is:
14+
```python
15+
diffeqsolve(
16+
...,
17+
solver=Dopri5(scan_stages=True),
18+
stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8),
19+
adjoint=BacksolveAdjoint(),
20+
max_steps=None,
21+
)
22+
```
23+
24+
In practice, `TSit5` is usually a better solver than `Dopri5`. And the default adjoint method (`RecursiveCheckpointAdjoint`) is usually a better choice than `BacksolveAdjoint`.
25+
1126
### I'm getting a `CustomVJPException`.
1227

1328
This can happen if you use [`diffrax.BacksolveAdjoint`][] incorrectly.

0 commit comments

Comments
 (0)