Skip to content

Commit 115997e

Browse files
Merge pull request #125 from patrick-kidger/solver-docs
Improved solver documentation
2 parents aa27945 + a3d827e commit 115997e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1467
-536
lines changed

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
with:
1515
python-version: "3.8"
1616
test-script: |
17-
python -m pip install pytest psutil jax jaxlib equinox scipy
17+
python -m pip install pytest psutil jax jaxlib equinox scipy optax
1818
cp -r ${{ github.workspace }}/test ./test
1919
pytest
2020
pypi-token: ${{ secrets.pypi_token }}

.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install pytest psutil wheel scipy numpy jaxlib
26+
python -m pip install pytest psutil wheel scipy numpy optax jaxlib
2727
2828
- name: Checks with pre-commit
2929
uses: pre-commit/action@v2.0.3

diffrax/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from .adjoint import (
22
AbstractAdjoint,
33
BacksolveAdjoint,
4+
ImplicitAdjoint,
45
NoAdjoint,
56
RecursiveCheckpointAdjoint,
67
)
78
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
9+
from .event import (
10+
AbstractDiscreteTerminatingEvent,
11+
DiscreteTerminatingEvent,
12+
SteadyStateEvent,
13+
)
814
from .global_interpolation import (
915
AbstractGlobalInterpolation,
1016
backward_hermite_coefficients,
@@ -31,7 +37,6 @@
3137
from .saveat import SaveAt
3238
from .solution import RESULTS, Solution
3339
from .solver import (
34-
AbstractAdaptiveSDESolver,
3540
AbstractAdaptiveSolver,
3641
AbstractDIRK,
3742
AbstractERK,
@@ -45,6 +50,7 @@
4550
AbstractWrappedSolver,
4651
Bosh3,
4752
ButcherTableau,
53+
CalculateJacobian,
4854
Dopri5,
4955
Dopri8,
5056
Euler,
@@ -81,4 +87,4 @@
8187
)
8288

8389

84-
__version__ = "0.1.2"
90+
__version__ = "0.2.0"

diffrax/adjoint.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import jax.lax as lax
77
import jax.numpy as jnp
88

9-
from .misc import nondifferentiable_output, ω
9+
from .misc import implicit_jvp, nondifferentiable_output, ω
1010
from .saveat import SaveAt
1111
from .term import AbstractTerm, AdjointTerm
1212

@@ -22,6 +22,7 @@ def loop(
2222
terms,
2323
solver,
2424
stepsize_controller,
25+
discrete_terminating_event,
2526
saveat,
2627
t0,
2728
t1,
@@ -92,6 +93,73 @@ def loop(self, *, throw, **kwargs):
9293
return final_state, aux_stats
9394

9495

96+
def _vf(ys, residual, args__terms, closure):
97+
state_no_y, _ = residual
98+
t = state_no_y.tprev
99+
(y,) = ys # unpack length-1 dimension
100+
args, terms = args__terms
101+
_, _, solver, _, _ = closure
102+
return solver.func(terms, t, y, args)
103+
104+
105+
def _solve(args__terms, closure):
106+
args, terms = args__terms
107+
self, kwargs, solver, saveat, init_state = closure
108+
final_state, aux_stats = self._loop_fn(
109+
**kwargs,
110+
args=args,
111+
terms=terms,
112+
solver=solver,
113+
saveat=saveat,
114+
init_state=init_state,
115+
is_bounded=False,
116+
)
117+
# Note that we use .ys not .y here. The former is what is actually returned
118+
# by diffeqsolve, so it is the thing we want to attach the tangent to.
119+
return final_state.ys, (
120+
eqx.tree_at(lambda s: s.ys, final_state, None),
121+
aux_stats,
122+
)
123+
124+
125+
class ImplicitAdjoint(AbstractAdjoint):
126+
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
127+
128+
This is used when solving towards a steady state, typically using
129+
[`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$
130+
for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found
131+
through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
132+
the solver and instead directly compute
133+
$\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$
134+
via the implicit function theorem.
135+
""" # noqa: E501
136+
137+
def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
138+
del throw
139+
140+
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
141+
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
142+
raise ValueError(
143+
"Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`."
144+
)
145+
146+
init_state = eqx.tree_at(
147+
lambda s: (s.y, s.solver_state, s.controller_state),
148+
init_state,
149+
replace_fn=lax.stop_gradient,
150+
)
151+
closure = (self, kwargs, solver, saveat, init_state)
152+
ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure)
153+
154+
final_state_no_ys, aux_stats = residual
155+
return (
156+
eqx.tree_at(
157+
lambda s: s.ys, final_state_no_ys, ys, is_leaf=lambda x: x is None
158+
),
159+
aux_stats,
160+
)
161+
162+
95163
# Compute derivatives with respect to the first argument:
96164
# - y, corresponding to the initial state;
97165
# - args, corresponding to explicit parameters;
@@ -116,7 +184,6 @@ def _loop_backsolve_fwd(y__args__terms, **kwargs):
116184
return (final_state, aux_stats), (ts, ys)
117185

118186

119-
# TODO: implement this as a single diffeqsolve with events, once events are supported.
120187
def _loop_backsolve_bwd(
121188
residuals,
122189
grad_final_state__aux_stats,
@@ -125,6 +192,7 @@ def _loop_backsolve_bwd(
125192
self,
126193
solver,
127194
stepsize_controller,
195+
discrete_terminating_event,
128196
saveat,
129197
t0,
130198
t1,
@@ -162,6 +230,7 @@ def _loop_backsolve_bwd(
162230
adjoint=self,
163231
solver=solver,
164232
stepsize_controller=stepsize_controller,
233+
discrete_terminating_event=discrete_terminating_event,
165234
terms=adjoint_terms,
166235
dt0=None if dt0 is None else -dt0,
167236
max_steps=max_steps,

diffrax/event.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import abc
2+
from typing import Callable, Optional
3+
4+
import equinox as eqx
5+
6+
from .custom_types import Bool, PyTree, Scalar
7+
from .misc import rms_norm
8+
from .step_size_controller import AbstractAdaptiveStepSizeController
9+
10+
11+
class AbstractDiscreteTerminatingEvent(eqx.Module):
12+
"""Evaluated at the end of each integration step. If true then the solve is stopped
13+
at that time.
14+
"""
15+
16+
@abc.abstractmethod
17+
def __call__(self, state, **kwargs):
18+
"""**Arguments:**
19+
20+
- `state`: a dataclass of the evolving state of the system, including in
21+
particular the solution `state.y` at time `state.tprev`.
22+
- `**kwargs`: the integration options held constant throughout the solve
23+
are passed as keyword arguments: `terms`, `solver`, `args`. etc.
24+
25+
**Returns**
26+
27+
A boolean. If true then the solve is terminated.
28+
"""
29+
30+
31+
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
32+
"""Terminates the solve if its condition is ever active."""
33+
34+
cond_fn: Callable[..., Bool]
35+
36+
def __call__(self, state, **kwargs):
37+
return self.cond_fn(state, **kwargs)
38+
39+
40+
DiscreteTerminatingEvent.__init__.__doc__ = """**Arguments:**
41+
42+
- `cond_fn`: A function `(state, **kwargs) -> bool` that is evaluated on every step of
43+
the differential equation solve. If it returns `True` then the solve is finished at
44+
that timestep. `state` is a dataclass of the evolving state of the system,
45+
including in particular the solution `state.y` at time `state.tprev`. Passed as
46+
keyword arguments are the `terms`, `solver`, `args` etc. that are constant
47+
throughout the solve.
48+
"""
49+
50+
51+
class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
52+
"""Terminates the solve once it reaches a steady state."""
53+
54+
rtol: Optional[float] = None
55+
atol: Optional[float] = None
56+
norm: Callable[[PyTree], Scalar] = rms_norm
57+
58+
def __call__(self, state, *, terms, args, solver, stepsize_controller, **kwargs):
59+
del kwargs
60+
_error = False
61+
if self.rtol is None:
62+
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
63+
_rtol = stepsize_controller.rtol
64+
else:
65+
_error = True
66+
else:
67+
_rtol = self.rtol
68+
if self.atol is None:
69+
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
70+
_atol = stepsize_controller.atol
71+
else:
72+
_error = True
73+
else:
74+
_atol = self.atol
75+
if _error:
76+
raise ValueError(
77+
"The `rtol` and `atol` tolerances for `SteadyStateEvent` default "
78+
"to the `rtol` and `atol` used with an adaptive step size "
79+
"controller (such as `diffrax.PIDController`). Either use an "
80+
"adaptive step size controller, or specify these tolerances "
81+
"manually."
82+
)
83+
84+
# TODO: this makes an additional function evaluation that in practice has
85+
# probably already been made by the solver.
86+
vf = solver.func(terms, state.tprev, state.y, args)
87+
return self.norm(vf) < _atol + _rtol * self.norm(state.y)
88+
89+
90+
SteadyStateEvent.__init__.__doc__ = """**Arguments:**
91+
92+
- `rtol`: The relative tolerance for determining convergence. Defaults to the
93+
same `rtol` as passed to an adaptive step controller if one is used.
94+
- `atol`: The absolute tolerance for determining convergence. Defaults to the
95+
same `atol` as passed to an adaptive step controller if one is used.
96+
- `norm`: A function `PyTree -> Scalar`, which is called to determine whether
97+
the vector field is close to zero.
98+
"""

0 commit comments

Comments
 (0)