
Commit a1003ef

Merge commit a1003ef (2 parents: b66190d + 2e701f8)

File tree: 3 files changed, +95 -27 lines


README.md

Lines changed: 9 additions & 5 deletions
```diff
@@ -59,12 +59,16 @@ If you found this library useful in academic research, please cite: [(arXiv link
 
 (Also consider starring the project on GitHub.)
 
-## See also
+## See also: other libraries in the JAX ecosystem
 
-Neural networks: [Equinox](https://github.com/patrick-kidger/equinox).
+[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
 
-Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping).
+[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
 
-Computer vision models: [Eqxvision](https://github.com/paganpasta/eqxvision).
+[Lineax](https://github.com/google/lineax): linear solvers and linear least squares.
 
-SymPy<->JAX conversion; train symbolic expressions via gradient descent: [sympy2jax](https://github.com/google/sympy2jax).
+[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+
+[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+
+[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
```

diffrax/solver/runge_kutta.py

Lines changed: 72 additions & 22 deletions
```diff
@@ -5,7 +5,6 @@
 import equinox as eqx
 import equinox.internal as eqxi
 import jax
-import jax.flatten_util as jfu
 import jax.lax as lax
 import jax.numpy as jnp
 import jax.tree_util as jtu
@@ -218,7 +217,19 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
 # numerical behaviour since the iteration is close to 0. (Although we have
 # multiplied by the increment of the control, i.e. dt, which is small...)
 def _implicit_relation_f(fi, nonlinear_solve_args):
-    diagonal, vf, prod, ti, yi_partial, args, control = nonlinear_solve_args
+    # We pass stage_index, even without using it, so that custom nonlinear solvers
+    # can special-case on the stage if they want to.
+    (
+        stage_index,
+        diagonal,
+        vf,
+        prod,
+        ti,
+        yi_partial,
+        args,
+        control,
+    ) = nonlinear_solve_args
+    del stage_index
     diff = (
         fi**ω
         - vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
@@ -231,7 +242,11 @@ def _implicit_relation_k(ki, nonlinear_solve_args):
     # c.f:
     # https://github.com/SciML/DiffEqDevMaterials/blob/master/newton/output/main.pdf
     # (Bearing in mind that our ki is dt times smaller than theirs.)
-    diagonal, vf_prod, ti, yi_partial, args, control = nonlinear_solve_args
+    #
+    # We pass stage_index, even without using it, so that custom nonlinear solvers
+    # can special-case on the stage if they want to.
+    stage_index, diagonal, vf_prod, ti, yi_partial, args, control = nonlinear_solve_args
+    del stage_index
     diff = (
         ki**ω
         - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω
```
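Note on the two hunks above: both residual functions now receive `stage_index` as the first element of `nonlinear_solve_args`, purely so that a custom nonlinear solver can special-case on the stage. A minimal, hypothetical sketch of that idea (toy residual and values, not diffrax's actual solver API):

```python
import jax.numpy as jnp

# Toy residual mirroring the new tuple layout: stage_index comes first.
# A custom solver could inspect it, e.g. to damp the update on later stages.
def toy_residual(fi, nonlinear_solve_args):
    stage_index, diagonal, yi_partial = nonlinear_solve_args
    damping = jnp.where(stage_index == 0, 1.0, 0.5)  # illustrative special-casing
    return fi - damping * (yi_partial + diagonal * fi)

print(toy_residual(jnp.array(1.0), (jnp.array(1), jnp.array(0.25), jnp.array(2.0))))
```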
```diff
@@ -732,6 +747,25 @@ def embed_c(tab):
         implicit_predictor = jnp.asarray(implicit_predictor)
         implicit_c = get_implicit(tableaus_c)
 
+        if implicit_term is None:
+            implicit_vf = _unused
+            implicit_prod = _unused
+            implicit_vf_prod = _unused
+        else:
+            if eval_fs:
+                assert f0 is not _unused
+                implicit_vf = eqx.filter_closure_convert(implicit_term.vf, t0, y0, args)
+                implicit_prod = eqx.filter_closure_convert(
+                    implicit_term.prod, get_implicit(f0), implicit_control
+                )
+                implicit_vf_prod = _unused
+            else:
+                implicit_vf = _unused
+                implicit_prod = _unused
+                implicit_vf_prod = eqx.filter_closure_convert(
+                    implicit_term.vf_prod, t0, y0, args, implicit_control
+                )
+
         #
         # Run the loop over stages. (This is what you signed up for, and it's taken us
         # several hundred lines of code just to get this far!)
```
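The block above hoists the `eqx.filter_closure_convert` calls out of the stage loop (the next hunks remove the per-stage calls they replace), so the implicit term is traced once per step rather than once per stage. A minimal sketch of that pattern, with a toy vector field standing in for `implicit_term.vf`; the function and values here are illustrative only:

```python
import equinox as eqx
import jax.numpy as jnp

scale = jnp.array(3.0)

def vf(t, y, args):
    # Toy vector field; `scale` is captured by closure.
    return scale * y

# Trace/convert once, against example arguments...
converted_vf = eqx.filter_closure_convert(vf, 0.0, jnp.array(1.0), None)

# ...then reuse the converted function as many times as needed, with new values
# of the same structure.
for y in (jnp.array(1.0), jnp.array(2.0)):
    print(converted_vf(0.0, y, None))
```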
```diff
@@ -791,11 +825,10 @@ def rk_stage(val):
                 f_pred = jtu.tree_map(if_first_stage, f0_for_jac, f_pred)
                 assert f0 is not _unused
                 f_implicit_args = (
+                    stage_index,
                     implicit_diagonal_i,
-                    eqx.filter_closure_convert(implicit_term.vf, t0, y0, args),
-                    eqx.filter_closure_convert(
-                        implicit_term.prod, f_pred, implicit_control
-                    ),
+                    implicit_vf,
+                    implicit_prod,
                     implicit_ti,
                     yi_partial,
                     args,
@@ -814,10 +847,9 @@ def rk_stage(val):
                 # doesn't matter.
                 k_pred = jtu.tree_map(if_first_stage, k0_for_jac, k_pred)
                 k_implicit_args = (
+                    stage_index,
                     implicit_diagonal_i,
-                    eqx.filter_closure_convert(
-                        implicit_term.vf_prod, t0, y0, args, implicit_control
-                    ),
+                    implicit_vf_prod,
                     implicit_ti,
                     yi_partial,
                     args,
@@ -954,23 +986,41 @@ def buffers(val):
         # For DIRK and SDIRK methods then the choice here doesn't matter; we compute
         # the Jacobian straight away.
         # For ESDIRK methods, this is the Jacobian of an explicit step.
-        #
-        # TODO: fix once we have more advanced nonlinear solvers.
-        # Mildly hacky hardcoding for now.
         if eval_fs:
             assert f0 is not _unused
-            struct = jax.eval_shape(lambda: jfu.ravel_pytree(get_implicit(f0))[0])
-            jac_f = (
-                jnp.eye(struct.size, dtype=struct.dtype),
-                jnp.arange(struct.size, dtype=jnp.int32),
+            f_implicit_args = (
+                jnp.array(0),
+                # zero diagonal == identity matrix as the Jacobian
+                jnp.array(0.0, dtype=implicit_diagonal.dtype),
+                implicit_vf,
+                implicit_prod,
+                t0,
+                y0,
+                args,
+                implicit_control,
+            )
+            jac_f = self.nonlinear_solver.jac(
+                _implicit_relation_f,
+                jtu.tree_map(jnp.zeros_like, get_implicit(f0)),
+                _filter_stop_gradient(f_implicit_args),
             )
             jac_k = _unused
         else:
-            struct = jax.eval_shape(lambda: jfu.ravel_pytree(y0)[0])
+            k_implicit_args = (
+                jnp.array(0),
+                # zero diagonal == identity matrix as the Jacobian
+                jnp.array(0.0, dtype=implicit_diagonal.dtype),
+                implicit_vf_prod,
+                t0,
+                y0,
+                args,
+                implicit_control,
+            )
             jac_f = _unused
-            jac_k = (
-                jnp.eye(struct.size, dtype=struct.dtype),
-                jnp.arange(struct.size, dtype=jnp.int32),
+            jac_k = self.nonlinear_solver.jac(
+                _implicit_relation_k,
+                jtu.tree_map(jnp.zeros_like, y0),
+                _filter_stop_gradient(k_implicit_args),
             )
         init_val = (
             init_stage_index,
```
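The comment "zero diagonal == identity matrix as the Jacobian" above is worth spelling out: with `diagonal = 0`, the residual `_implicit_relation_f` reduces to `fi - vf(ti, yi_partial, args)`, whose Jacobian with respect to `fi` is the identity, so the hardcoded `jnp.eye` can be replaced by a call to `self.nonlinear_solver.jac` on a zero-diagonal relation. A toy check of the same fact (illustrative stand-in vector field, not diffrax code):

```python
import jax
import jax.numpy as jnp

def residual(fi, diagonal, yi_partial):
    vf = lambda y: jnp.tanh(y)  # stand-in vector field
    return fi - vf(yi_partial + diagonal * fi)

fi = jnp.array([0.1, -0.3])
# With diagonal = 0 the vf term does not depend on fi, so d(residual)/d(fi) = I.
print(jax.jacfwd(residual)(fi, 0.0, jnp.array([1.0, 2.0])))  # identity matrix
```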
```diff
@@ -1065,7 +1115,7 @@ def __init_subclass__(cls, **kwargs):
         diagonal = cls.tableau.a_diagonal[0]
         assert (cls.tableau.a_diagonal == diagonal).all()
 
-        calculate_jacobian = CalculateJacobian.second_stage
+        calculate_jacobian = CalculateJacobian.first_stage
 
 
 class AbstractESDIRK(AbstractDIRK):
```

docs/index.md

Lines changed: 14 additions & 0 deletions
```diff
@@ -46,3 +46,17 @@ Here, `Dopri5` refers to the Dormand--Prince 5(4) numerical differential equatio
 ## Next steps
 
 Have a look at the [Getting Started](./usage/getting-started.md) page.
+
+## See also: other libraries in the JAX ecosystem
+
+[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
+
+[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
+
+[Lineax](https://github.com/google/lineax): linear solvers and linear least squares.
+
+[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+
+[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+
+[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
```
