
Commit ea1bdc9

Merge pull request #190 from patrick-kidger/internal4
Upgrade to `equinox.internal`
2 parents: b847552 + ee39d82. Commit: ea1bdc9


45 files changed: +637, -933 lines changed

benchmarks/compile_times.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import functools as ft
+import timeit
+
+import diffrax as dfx
+import equinox as eqx
+import fire
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+
+
+def _weight(in_, out, key):
+    return [[w_ij for w_ij in w_i] for w_i in jr.normal(key, (out, in_))]
+
+
+class VectorField(eqx.Module):
+    weights: list
+
+    def __init__(self, in_, out, width, depth, *, key):
+        keys = jr.split(key, depth + 1)
+        self.weights = [_weight(in_, width, keys[0])]
+        for i in range(1, depth):
+            self.weights.append(_weight(width, width, keys[i]))
+        self.weights.append(_weight(width, out, keys[depth]))
+
+    def __call__(self, t, y, args):
+        # Inefficient computation graph to make a toy example more expensive.
+        y = [y_i for y_i in y]
+        for w in self.weights:
+            y = [sum(w_ij * y_j for w_ij, y_j in zip(w_i, y)) for w_i in w]
+        return jnp.stack(y)
+
+
+def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str):
+    if adjoint == "direct":
+        adjoint = dfx.DirectAdjoint()
+    elif adjoint == "recursive":
+        adjoint = dfx.RecursiveCheckpointAdjoint()
+    elif adjoint == "backsolve":
+        adjoint = dfx.BacksolveAdjoint()
+    else:
+        raise ValueError
+    if grad:
+        grad_decorator = jax.grad
+    else:
+        grad_decorator = lambda x: x
+
+    vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0))
+    if not inline:
+        vf = eqx.internal.noinline(vf)
+    term = dfx.ODETerm(vf)
+    solver = dfx.Dopri8(scan_stages=scan_stages)
+    stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6)
+    t0 = 0
+    t1 = 1
+    dt0 = 0.01
+
+    @jax.jit
+    @grad_decorator
+    def solve(y0):
+        sol = dfx.diffeqsolve(
+            term,
+            solver,
+            t0,
+            t1,
+            dt0,
+            y0,
+            stepsize_controller=stepsize_controller,
+            adjoint=adjoint,
+            max_steps=16**2,
+        )
+        return jnp.sum(sol.ys)
+
+    solve_ = ft.partial(solve, jnp.array([1.0]))
+    print("Compile+run time", timeit.timeit(solve_, number=1))
+    print("Run time", timeit.timeit(solve_, number=1))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
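A usage sketch (not part of the commit): because the benchmark hands `main` to `fire.Fire`, it can be driven from the command line with flags such as `--inline=False --scan_stages=True --grad=True --adjoint=recursive`, or directly from Python. The driver below is hypothetical and only illustrates how one might compare the inlined and `noinline`d vector fields:

    # Hypothetical driver; assumes it is run from the benchmarks/ directory.
    from compile_times import main

    for inline in (True, False):
        print(f"inline={inline}")
        main(inline=inline, scan_stages=True, grad=True, adjoint="recursive")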

benchmarks/small_neural_ode.py

Lines changed: 5 additions & 4 deletions
@@ -9,6 +9,7 @@
 import jax.nn as jnn
 import jax.numpy as jnp
 import jax.random as jrandom
+import numpy as np
 import torch
 import torchdiffeq


@@ -173,10 +174,10 @@ def main(batch_size=64, t1=100, multiple=False, grad=False):
     with torch.no_grad():
         func_jax = neural_ode_diffrax.func.func
         func_torch = neural_ode_torch.func.func
-        func_torch[0].weight.copy_(torch.tensor(func_jax.layers[0].weight.to_py()))
-        func_torch[0].bias.copy_(torch.tensor(func_jax.layers[0].bias.to_py()))
-        func_torch[2].weight.copy_(torch.tensor(func_jax.layers[1].weight.to_py()))
-        func_torch[2].bias.copy_(torch.tensor(func_jax.layers[1].bias.to_py()))
+        func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
+        func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
+        func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
+        func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

     y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4))
     y0_torch = torch.tensor(y0_jax.to_py())
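For context on the change above: `.to_py()` on JAX arrays has been deprecated in favour of `np.asarray(...)` in newer JAX releases, which is why the weight-copying lines switch over. A minimal sketch of the replacement pattern, using only libraries the benchmark already imports:

    import jax.numpy as jnp
    import numpy as np
    import torch

    x = jnp.ones((2, 3))
    x_np = np.asarray(x)          # portable JAX -> NumPy conversion
    x_torch = torch.tensor(x_np)  # same pattern as in the diff above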

diffrax/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -87,4 +87,4 @@
 )


-__version__ = "0.2.1"
+__version__ = "0.2.2"

diffrax/adjoint.py

Lines changed: 108 additions & 31 deletions
@@ -2,15 +2,67 @@
 from typing import Any, Dict

 import equinox as eqx
+import equinox.internal as eqxi
 import jax.lax as lax
 import jax.numpy as jnp
 import jax.tree_util as jtu
+from equinox.internal import ω

-from .misc import implicit_jvp, nondifferentiable_output, ω
+from .misc import implicit_jvp
 from .saveat import SaveAt
 from .term import AbstractTerm, AdjointTerm


+def _is_none(x):
+    return x is None
+
+
+def _no_transpose_final_state(final_state):
+    y = eqxi.nondifferentiable_backward(final_state.y, name="y")
+    tprev = eqxi.nondifferentiable_backward(final_state.tprev, name="tprev")
+    tnext = eqxi.nondifferentiable_backward(final_state.tnext, name="tnext")
+    solver_state = eqxi.nondifferentiable_backward(
+        final_state.solver_state, name="solver_state"
+    )
+    controller_state = eqxi.nondifferentiable_backward(
+        final_state.controller_state, name="controller_state"
+    )
+    ts = eqxi.nondifferentiable_backward(final_state.ts, name="ts")
+    ys = final_state.ys
+    dense_ts = eqxi.nondifferentiable_backward(final_state.dense_ts, name="dense_ts")
+    dense_infos = eqxi.nondifferentiable_backward(
+        final_state.dense_infos, name="dense_infos"
+    )
+    final_state = eqxi.nondifferentiable_backward(final_state)  # no more specific name
+    final_state = eqx.tree_at(
+        lambda s: (
+            s.y,
+            s.tprev,
+            s.tnext,
+            s.solver_state,
+            s.controller_state,
+            s.ts,
+            s.ys,
+            s.dense_ts,
+            s.dense_infos,
+        ),
+        final_state,
+        (
+            y,
+            tprev,
+            tnext,
+            solver_state,
+            controller_state,
+            ts,
+            ys,
+            dense_ts,
+            dense_infos,
+        ),
+        is_leaf=_is_none,
+    )
+    return final_state
+
+
 class AbstractAdjoint(eqx.Module):
     """Abstract base class for all adjoint methods."""

@@ -30,6 +82,8 @@ def loop(
         max_steps,
         throw,
         init_state,
+        passed_solver_state,
+        passed_controller_state,
     ):
         """Runs the main solve loop. Subclasses can override this to provide custom
         backpropagation behaviour; see for example the implementation of
@@ -69,27 +123,26 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
     For most problems this is the preferred technique for backpropagating through a
     differential equation.

-    A binomial checkpointing scheme is used so that memory usage is low.
+    In addition a binomial checkpointing scheme is used so that memory usage is low.
+    (This checkpointing can increase compile time a bit, though.)
     """

-    def loop(self, *, throw, **kwargs):
-        del throw
+    def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
+        del throw, passed_solver_state, passed_controller_state
         return self._loop_fn(**kwargs, is_bounded=True)


 class NoAdjoint(AbstractAdjoint):
     """Disable backpropagation through [`diffrax.diffeqsolve`][].
-
     Forward-mode autodifferentiation (`jax.jvp`) will continue to work as normal.
-
     If you do not need to differentiate the results of [`diffrax.diffeqsolve`][] then
     this may sometimes improve the speed at which the differential equation is solved.
     """

-    def loop(self, *, throw, **kwargs):
-        del throw
+    def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
+        del throw, passed_solver_state, passed_controller_state
         final_state, aux_stats = self._loop_fn(**kwargs, is_bounded=False)
-        final_state = jtu.tree_map(nondifferentiable_output, final_state)
+        final_state = eqxi.nondifferentiable_backward(final_state)
         return final_state, aux_stats


@@ -135,7 +188,19 @@ class ImplicitAdjoint(AbstractAdjoint):
     via the implicit function theorem.
     """  # noqa: E501

-    def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        solver,
+        saveat,
+        throw,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
         del throw

         # `is` check because this may return a Tracer from SaveAt(ts=<array>)
@@ -144,21 +209,30 @@ def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
                 "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`."
             )

-        init_state = eqx.tree_at(
-            lambda s: (s.y, s.solver_state, s.controller_state),
-            init_state,
-            replace_fn=lax.stop_gradient,
-        )
+        if not passed_solver_state:
+            init_state = eqx.tree_at(
+                lambda s: s.solver_state,
+                init_state,
+                replace_fn=lax.stop_gradient,
+                is_leaf=_is_none,
+            )
+        if not passed_controller_state:
+            init_state = eqx.tree_at(
+                lambda s: s.controller_state,
+                init_state,
+                replace_fn=lax.stop_gradient,
+                is_leaf=_is_none,
+            )
+
         closure = (self, kwargs, solver, saveat, init_state)
         ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure)

         final_state_no_ys, aux_stats = residual
-        return (
-            eqx.tree_at(
-                lambda s: s.ys, final_state_no_ys, ys, is_leaf=lambda x: x is None
-            ),
-            aux_stats,
+        final_state = eqx.tree_at(
+            lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none
         )
+        final_state = _no_transpose_final_state(final_state)
+        return final_state, aux_stats


 # Compute derivatives with respect to the first argument:
@@ -174,7 +248,7 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
     )
     del y
     return self._loop_fn(
-        args=args, terms=terms, init_state=init_state, **kwargs, is_bounded=False
+        args=args, terms=terms, init_state=init_state, is_bounded=False, **kwargs
     )


@@ -398,7 +472,18 @@ def __init__(self, **kwargs):
         )
         self.kwargs = kwargs

-    def loop(self, *, args, terms, saveat, init_state, **kwargs):
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        saveat,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del passed_solver_state, passed_controller_state
         if saveat.steps or saveat.dense:
             raise NotImplementedError(
                 "Cannot use `adjoint=BacksolveAdjoint()` with "
@@ -414,13 +499,5 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
         final_state, aux_stats = _loop_backsolve(
             (y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs
         )
-
-        # We only allow backpropagation through `ys`; in particular not through
-        # `solver_state` etc.
-        ys = final_state.ys
-        final_state = jtu.tree_map(nondifferentiable_output, final_state)
-        final_state = eqx.tree_at(
-            lambda s: jtu.tree_leaves(s.ys), final_state, jtu.tree_leaves(ys)
-        )
-
+        final_state = _no_transpose_final_state(final_state)
         return final_state, aux_stats
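For intuition about the refactor above (an illustrative sketch only, not how `equinox.internal` implements it): `eqxi.nondifferentiable_backward` leaves the forward pass untouched but blocks reverse-mode cotangents, and the `name=...` arguments in `_no_transpose_final_state` suggest the resulting error can identify the offending field. A toy stand-in built on `jax.custom_vjp` might look like:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def block_backward(x):
        # Forward pass: identity.
        return x

    def _fwd(x):
        return x, None

    def _bwd(_, cotangent):
        # Reverse pass: refuse to propagate a cotangent.
        raise RuntimeError("Attempted to differentiate a non-differentiable output.")

    block_backward.defvjp(_fwd, _bwd)

    print(block_backward(jnp.array(1.0)))         # forward evaluation is fine
    # jax.grad(lambda x: block_backward(x))(1.0)  # would raise when traced

The upshot is the same as the deleted comment stated: gradients propagate only through `ys`, not through `solver_state`, `ts`, or the other solve metadata.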

diffrax/brownian/base.py

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
 import abc

-from ..custom_types import Array, Scalar
+from ..custom_types import Array, PyTree, Scalar
 from ..path import AbstractPath


 class AbstractBrownianPath(AbstractPath):
     "Abstract base class for all Brownian paths."

     @abc.abstractmethod
-    def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array:
+    def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
         r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

         Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
@@ -23,7 +23,7 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array:

         **Returns:**

-        A JAX array corresponding to the increment $w(t_1) - w(t_0)$.
+        A pytree of JAX arrays corresponding to the increment $w(t_1) - w(t_0)$.

         Some subclasses may allow `t1=None`, in which case just the value $w(t_0)$ is
         returned.
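To illustrate the widened `PyTree[Array]` annotation (a hypothetical example, not taken from the library): a Brownian-motion-like control may now return its increment as any pytree of arrays, for instance a dict with one entry per driving channel, so long as each increment is distributed as $\mathcal{N}(0, t_1 - t_0)$:

    import jax.numpy as jnp
    import jax.random as jr

    def toy_brownian_increment(key, t0, t1):
        # Returns a pytree (here a dict) of independent Brownian increments.
        k1, k2 = jr.split(key)
        std = jnp.sqrt(t1 - t0)
        return {
            "u": std * jr.normal(k1, (2,)),
            "v": std * jr.normal(k2, (3,)),
        }

    print(toy_brownian_increment(jr.PRNGKey(0), 0.0, 0.5))

A real `AbstractBrownianPath` subclass additionally has to return consistent values when queried over overlapping intervals, which this toy function makes no attempt to do.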
