Skip to content

Commit 6af45a5

Browse files
Updated to latest Equinox (v0.10.8)
1 parent aadbeae commit 6af45a5

File tree

7 files changed

+50
-17
lines changed

7 files changed

+50
-17
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
2121
pip install diffrax
2222
```
2323

24-
Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+.
24+
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.8+.
2525

2626
## Documentation
2727

diffrax/adjoint.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,9 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
493493
)
494494

495495

496-
def _loop_backsolve_fwd(y__args__terms, **kwargs):
496+
@_loop_backsolve.def_fwd
497+
def _loop_backsolve_fwd(perturbed, y__args__terms, **kwargs):
498+
del perturbed
497499
final_state, aux_stats = _loop_backsolve(y__args__terms, **kwargs)
498500
# Note that `final_state.save_state` has type `PyTree[SaveState]`; here we are
499501
# relying on the guard in `BacksolveAdjoint` that it have trivial structure.
@@ -502,9 +504,18 @@ def _loop_backsolve_fwd(y__args__terms, **kwargs):
502504
return (final_state, aux_stats), (ts, ys)
503505

504506

507+
def _materialise_none(y, grad_y):
508+
if grad_y is None and eqx.is_inexact_array(y):
509+
return jnp.zeros_like(y)
510+
else:
511+
return grad_y
512+
513+
514+
@_loop_backsolve.def_bwd
505515
def _loop_backsolve_bwd(
506516
residuals,
507517
grad_final_state__aux_stats,
518+
perturbed,
508519
y__args__terms,
509520
*,
510521
self,
@@ -525,13 +536,15 @@ def _loop_backsolve_bwd(
525536
# using them later.
526537
#
527538

528-
del init_state, t1
539+
del perturbed, init_state, t1
529540
ts, ys = residuals
530541
del residuals
531542
grad_final_state, _ = grad_final_state__aux_stats
532543
# Note that `grad_final_state.save_state` has type `PyTree[SaveState]`; here we are
533544
# relying on the guard in `BacksolveAdjoint` that it have trivial structure.
534545
grad_ys = grad_final_state.save_state.ys
546+
# We take the simple way out and don't try to handle symbolic zeros.
547+
grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
535548
del grad_final_state, grad_final_state__aux_stats
536549
y, args, terms = y__args__terms
537550
del y__args__terms
@@ -662,9 +675,6 @@ def __get(__aug):
662675
return a_y1, a_diff_args1, a_diff_terms1
663676

664677

665-
_loop_backsolve.defvjp(_loop_backsolve_fwd, _loop_backsolve_bwd)
666-
667-
668678
class BacksolveAdjoint(AbstractAdjoint):
669679
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
670680
adjoint equations backwards-in-time. This is also sometimes known as

diffrax/term.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,13 @@ class ODETerm(AbstractTerm):
170170
vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
171171

172172
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
173-
return self.vector_field(t, y, args)
173+
out = self.vector_field(t, y, args)
174+
if jtu.tree_structure(out) != jtu.tree_structure(y):
175+
raise ValueError(
176+
"The vector field inside `ODETerm` must return a pytree with the "
177+
"same structure as `y0`."
178+
)
179+
return jtu.tree_map(lambda o, yi: jnp.broadcast_to(o, jnp.shape(yi)), out, y)
174180

175181
@staticmethod
176182
def contr(t0: Scalar, t1: Scalar) -> Scalar:

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ _From a technical point of view, the internal structure of the library is pretty
2020
pip install diffrax
2121
```
2222

23-
Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+.
23+
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.8+.
2424

2525
## Quick example
2626

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
python_requires = "~=3.9"
4848

49-
install_requires = ["jax>=0.4.3", "equinox>=0.10.4"]
49+
install_requires = ["jax>=0.4.13", "equinox>=0.10.8"]
5050

5151
setuptools.setup(
5252
name=name,

test/test_adjoint.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ def _run(y0__args__term, saveat, adjoint):
5151
).ys
5252
)
5353

54+
def _run_finite_diff(y0__args__term, saveat, adjoint):
55+
y0, args, term = y0__args__term
56+
y0_a = y0 + jnp.array([1e-5, 0])
57+
y0_b = y0 + jnp.array([0, 1e-5])
58+
val = _run((y0, args, term), saveat, adjoint)
59+
val_a = _run((y0_a, args, term), saveat, adjoint)
60+
val_b = _run((y0_b, args, term), saveat, adjoint)
61+
out_a = (val_a - val) / 1e-5
62+
out_b = (val_b - val) / 1e-5
63+
return jnp.stack([out_a, out_b])
64+
5465
diff, nondiff = eqx.partition(y0__args__term, eqx.is_inexact_array)
5566
_run_grad = eqx.filter_jit(
5667
jax.grad(
@@ -85,11 +96,15 @@ def _convert_float0(x):
8596
continue
8697
saveat = diffrax.SaveAt(t0=t0, t1=t1, ts=ts)
8798

99+
fd_grads = _run_finite_diff(
100+
y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint()
101+
)
88102
direct_grads = _run_grad(diff, saveat, diffrax.DirectAdjoint())
89103
recursive_grads = _run_grad(
90104
diff, saveat, diffrax.RecursiveCheckpointAdjoint()
91105
)
92106
backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint())
107+
assert shaped_allclose(fd_grads, direct_grads[0])
93108
assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
94109
assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
95110

test/test_term.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,18 @@ def derivative(self, t):
3636
term = diffrax.ControlTerm(vector_field, control)
3737
args = getkey()
3838
dx = term.contr(0, 1)
39-
vf = term.vf(0, None, args)
40-
vf_prod = term.vf_prod(0, None, args, dx)
39+
y = jnp.array([1.0, 2.0, 3.0])
40+
vf = term.vf(0, y, args)
41+
vf_prod = term.vf_prod(0, y, args, dx)
4142
assert dx.shape == (2,)
4243
assert vf.shape == (3, 2)
4344
assert vf_prod.shape == (3,)
4445
assert shaped_allclose(vf_prod, term.prod(vf, dx))
4546

4647
term = term.to_ode()
4748
dt = term.contr(0, 1)
48-
vf = term.vf(0, None, args)
49-
vf_prod = term.vf_prod(0, None, args, dt)
49+
vf = term.vf(0, y, args)
50+
vf_prod = term.vf_prod(0, y, args, dt)
5051
assert vf.shape == (3,)
5152
assert vf_prod.shape == (3,)
5253
assert shaped_allclose(vf_prod, term.prod(vf, dt))
@@ -70,17 +71,18 @@ def derivative(self, t):
7071
term = diffrax.WeaklyDiagonalControlTerm(vector_field, control)
7172
args = getkey()
7273
dx = term.contr(0, 1)
73-
vf = term.vf(0, None, args)
74-
vf_prod = term.vf_prod(0, None, args, dx)
74+
y = jnp.array([1.0, 2.0, 3.0])
75+
vf = term.vf(0, y, args)
76+
vf_prod = term.vf_prod(0, y, args, dx)
7577
assert dx.shape == (3,)
7678
assert vf.shape == (3,)
7779
assert vf_prod.shape == (3,)
7880
assert shaped_allclose(vf_prod, term.prod(vf, dx))
7981

8082
term = term.to_ode()
8183
dt = term.contr(0, 1)
82-
vf = term.vf(0, None, args)
83-
vf_prod = term.vf_prod(0, None, args, dt)
84+
vf = term.vf(0, y, args)
85+
vf_prod = term.vf_prod(0, y, args, dt)
8486
assert vf.shape == (3,)
8587
assert vf_prod.shape == (3,)
8688
assert shaped_allclose(vf_prod, term.prod(vf, dt))

0 commit comments

Comments (0)