Skip to content

Commit 0f4e130

Browse files
Improved term docs.
1 parent c27f6c7 commit 0f4e130

File tree

3 files changed

+198
-51
lines changed

3 files changed

+198
-51
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
33
rev: v0.2.2
44
hooks:
5+
- id: ruff-format # formatter
6+
types_or: [ python, pyi, jupyter ]
57
- id: ruff # linter
68
types_or: [ python, pyi, jupyter ]
79
args: [ --fix ]
8-
- id: ruff-format # formatter
9-
types_or: [ python, pyi, jupyter ]
1010
- repo: https://github.com/RobertCraigie/pyright-python
1111
rev: v1.1.350
1212
hooks:

diffrax/_term.py

Lines changed: 158 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def prod(self, vf: _VF, control: _Control) -> Y:
8484
control $x$, this computes $f(t, y(t), args) \Delta x(t)$ given
8585
$f(t, y(t), args)$ and $\Delta x(t)$.
8686
87+
!!! note
88+
89+
This function must be bilinear.
90+
8791
**Arguments:**
8892
8993
- `vf`: The vector field evaluation; a PyTree of structure $S$.
@@ -93,10 +97,6 @@ def prod(self, vf: _VF, control: _Control) -> Y:
9397
9498
The interaction between the vector field and control; a PyTree of structure
9599
$T$.
96-
97-
!!! note
98-
99-
This function must be bilinear.
100100
"""
101101
pass
102102

@@ -260,6 +260,8 @@ def _prod(vf, control):
260260
return jnp.tensordot(vf, control, axes=jnp.ndim(control))
261261

262262

263+
# This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we
264+
# were writing things again today it would be folded into just `ControlTerm`.
263265
class _AbstractControlTerm(AbstractTerm[_VF, _Control]):
264266
vector_field: Callable[[RealScalarLike, Y, Args], _VF]
265267
control: Union[
@@ -290,61 +292,143 @@ def to_ode(self) -> ODETerm:
290292
- `vector_field`: A callable representing the vector field. This callable takes three
291293
arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is
292294
the evolving state of the system. `args` are any static arguments as passed to
293-
[`diffrax.diffeqsolve`][]. This `vector_field` can be a function that returns a
294-
JAX array, or returns any [lineax `AbstractLinearOperator`](https://docs.kidger.site/lineax/api/operators/#lineax.AbstractLinearOperator).
295-
- `control`: The control. Should either be (A) a [`diffrax.AbstractPath`][], in which
296-
case its `evaluate(t0, t1)` method will be used to give the increment of the control
297-
over a time interval `[t0, t1]`, or (B) a callable `(t0, t1) -> increment`, which
298-
returns the increment directly.
295+
[`diffrax.diffeqsolve`][]. This `vector_field` can either be
296+
297+
1. a function that returns a PyTree of JAX arrays, or
298+
2. it can return a
299+
[Lineax linear operator](https://docs.kidger.site/lineax/api/operators),
300+
as described above.
301+
302+
- `control`: The control. Should either be
303+
304+
1. a [`diffrax.AbstractPath`][], in which case its `.evaluate(t0, t1)` method
305+
will be used to give the increment of the control over a time interval
306+
`[t0, t1]`, or
307+
2. a callable `(t0, t1) -> increment`, which returns the increment directly.
299308
"""
300309

301310

302311
class ControlTerm(_AbstractControlTerm[_VF, _Control]):
303312
r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in
304-
which the vector field - control interaction is a matrix-vector product.
313+
which the vector field ($f$) - control ($\mathrm{d}x$) interaction is a
314+
matrix-vector product.
305315
306-
`vector_field` and `control` should both return PyTrees, both with the same
307-
structure as the initial state `y0`. Every dimension of `control` is then
308-
contracted against the last dimensions of `vector_field`; that is to say if each
309-
leaf of `y0` has shape `(y1, ..., yN)`, and the corresponding leaf of `control`
310-
has shape `(c1, ..., cM)`, then the corresponding leaf of `vector_field` should
311-
have shape `(y1, ..., yN, c1, ..., cM)`.
316+
This is typically used for either stochastic differential equations or for
317+
controlled differential equations.
312318
313-
A common special case is when `y0` and `control` are vector-valued, and
314-
`vector_field` is matrix-valued.
319+
`ControlTerm` can be used in two different ways.
315320
316-
To make a weakly diagonal control term, simply use your vector field
317-
callable return a `lx.DiagonalLinearOperator`.
321+
1. Simple way: directly return JAX arrays.
318322
319-
!!! info
323+
`vector_field` and `control` should both return PyTrees, both with the same
324+
structure as the initial state `y0`. All leaves should be JAX arrays.
320325
321-
Why "weakly" diagonal? Consider the matrix representation of the vector field,
322-
as a square diagonal matrix. In general, the (i,i)-th element may depending
323-
upon any of the values of `y`. It is only if the (i,i)-th element only depends
324-
upon the i-th element of `y` that the vector field is said to be "diagonal",
325-
without the "weak". (This stronger property is useful in some SDE solvers.)
326+
If each leaf of `y0` has shape `(y1, ..., yN)`, and the corresponding leaf of
327+
`control` has shape `(c1, ..., cM)`, then the corresponding leaf of
328+
`vector_field` should have shape `(y1, ..., yN, c1, ..., cM)`. Leaf-by-leaf, the
329+
corresponding dimensions of `vector_field` and control are contracted against
330+
each other.
326331
327-
!!! example
332+
This includes normal matrix-vector products as a special case: when `y0` is an
333+
array with shape `(m,)`, the control is an array with shape `(n,)`, and the
334+
vector field is an array with shape `(m, n)`.
335+
336+
2. Advanced way: have the vector field return a [Lineax linear operator](https://docs.kidger.site/lineax/api/operators).
337+
338+
This is suitable for use cases in which you know that the vector field has
339+
special structure -- e.g. it is diagonal -- and you would like to use that
340+
structure for a more efficient implementation.
341+
342+
In this case, then `vector_field` should return a
343+
[Lineax linear operator](https://docs.kidger.site/lineax/api/operators), the
344+
control can return anything compatible with the
345+
[`.mv`](https://docs.kidger.site/lineax/api/operators/#lineax.AbstractLinearOperator.mv)
346+
method of that operator, and the interaction is defined as
347+
`vector_field(t0, y, arg).mv(control(t0, t1))`.
348+
349+
In this case no special PyTree handling is done -- perform this inside the
350+
operator's `.mv` if required. (As you can see, this approach is basically about
351+
deferring the whole linear operation to Lineax.)
352+
353+
!!! Example
354+
355+
In this example we consider an SDE with `m`-dimensional state
356+
$y \in \mathbb{R}^m$, an `n`-dimensional Brownian motion
357+
$W(t) \in \mathbb{R}^n$, and a constant diffusion of shape `(m, n)`.
358+
359+
$\mathrm{d}y(t) = \begin{bmatrix} 1 & ... & 1 \\ & ... & \\ 1 & ... & 1 \end{bmatrix} \mathrm{d}W(t)$
328360
329361
```python
362+
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
363+
364+
y0 = jnp.ones((m,))
365+
control = UnsafeBrownianPath(shape=(n,), key=...)
366+
367+
def vector_field(t, y, args):
368+
return jnp.ones((m, n))
369+
370+
diffusion_term = ControlTerm(vector_field, control)
371+
diffeqsolve(terms=diffusion_term, y0=y0, ...)
372+
```
373+
!!! Example
374+
375+
In this example we consider an SDE with a one-dimensional state
376+
$y(t) \in \mathbb{R}$ and a two-dimensional Brownian motion
377+
$W(t) \in \mathbb{R}^2$, given by:
378+
379+
$\mathrm{d}y(t) = \begin{bmatrix} y(t) \\ y(t) + 1 \end{bmatrix} \mathrm{d}W(t)$
380+
381+
We use the simple matrix-vector product way of combining things.
382+
383+
```python
384+
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
385+
330386
control = UnsafeBrownianPath(shape=(2,), key=...)
331-
vector_field = lambda t, y, args: lx.DiagonalLinearOperator(jnp.ones_like(y))
387+
388+
def vector_field(t, y, args):
389+
return jnp.stack([y, y + 1], axis=-1)
390+
332391
diffusion_term = ControlTerm(vector_field, control)
333392
diffeqsolve(diffusion_term, ...)
334393
```
335394
336-
!!! example
395+
!!! Example
396+
397+
In this example we consider an SDE with two-dimensional state
398+
$(y_1(t), y_2(t)) \in \mathbb{R}^2$ and a two-dimensional Brownian motion
399+
$W(t) \in \mathbb{R}^2$ -- and for which the diffusion matrix is
400+
diagonal.
401+
402+
$\mathrm{d}\begin{bmatrix} y_1 \\ y_2 \end{bmatrix}(t) = \begin{bmatrix} y_2(t) & 0 \\ 0 & y_1(t) \end{bmatrix} \mathrm{d}W(t)$
403+
404+
As such we use the more-advanced approach of using
405+
[Lineax](https://github.com/patrick-kidger/lineax/)'s linear operators to
406+
represent the diffusion matrix.
337407
338408
```python
409+
from diffrax import ControlTerm, diffeqsolve, UnsafeBrownianPath
410+
339411
control = UnsafeBrownianPath(shape=(2,), key=...)
340-
vector_field = lambda t, y, args: jnp.stack([y, y], axis=-1)
412+
413+
def vector_field(t, y, args):
414+
# y is a JAX array of shape (2,)
415+
y1, y2 = y
416+
diagonal = jnp.array([y2, y1])
417+
return lineax.DiagonalLinearOperator(diagonal)
418+
341419
diffusion_term = ControlTerm(vector_field, control)
342420
diffeqsolve(diffusion_term, ...)
343421
```
344422
345-
!!! example
423+
!!! Example
424+
425+
In this example we consider a controlled differnetial equation, for which the
426+
control is given by an interpolation of some data. (See also the
427+
[neural controlled differential equation](../examples/neural_cde/) example.)
346428
347429
```python
430+
from diffrax import ControlTerm, diffeqsolve, LinearInterpolation, UnsafeBrownianPath
431+
348432
ts = jnp.array([1., 2., 2.5, 3.])
349433
data = jnp.array([[0.1, 2.0],
350434
[0.3, 1.5],
@@ -355,16 +439,29 @@ class ControlTerm(_AbstractControlTerm[_VF, _Control]):
355439
cde_term = ControlTerm(vector_field, control)
356440
diffeqsolve(cde_term, ...)
357441
```
358-
"""
442+
""" # noqa: E501
359443

360444
def prod(self, vf: _VF, control: _Control) -> Y:
361445
if isinstance(vf, lx.AbstractLinearOperator):
362446
return vf.mv(control)
363-
return jtu.tree_map(_prod, vf, control)
447+
else:
448+
return jtu.tree_map(_prod, vf, control)
364449

365450

366451
class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):
367-
r"""A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in
452+
r"""
453+
DEPRECATED. Prefer:
454+
455+
```python
456+
def vector_field(t, y, args):
457+
return lineax.DiagonalLinearOperator(...)
458+
459+
diffrax.ControlTerm(vector_field, ...)
460+
```
461+
462+
---
463+
464+
A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in
368465
which the vector field - control interaction is a matrix-vector product, and the
369466
matrix is square and diagonal. In this case we may represent the matrix as a vector
370467
of just its diagonal elements. The matrix-vector product may be calculated by
@@ -385,14 +482,34 @@ class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):
385482
without the "weak". (This stronger property is useful in some SDE solvers.)
386483
"""
387484

388-
def __init__(self, *args, **kwargs):
485+
def __check_init__(self):
389486
warnings.warn(
390-
"WeaklyDiagonalControlTerm is pending deprecation and may be removed "
391-
"in future versions. Consider using the new alternative "
392-
"ControlTerm(lx.DiagonalLinearOperator(...)).",
393-
DeprecationWarning,
487+
"`WeaklyDiagonalControlTerm` is now deprecated, in favour combining "
488+
"`ControlTerm` with a `lineax.AbstractLinearOperator`. This offers a way "
489+
"to define a vector field with any kind of structure -- diagonal or "
490+
"otherwise.\n"
491+
"For a diagonal linear operator, then this can be easily converted as "
492+
"follows. What was previously:\n"
493+
"```\n"
494+
"def vector_field(t, y, args):\n"
495+
" ...\n"
496+
" return some_vector\n"
497+
"\n"
498+
"diffrax.WeaklyDiagonalControlTerm(vector_field)\n"
499+
"```\n"
500+
"is now:\n"
501+
"```\n"
502+
"import lineax\n"
503+
"\n"
504+
"def vector_field(t, y, args):\n"
505+
" ...\n"
506+
" return lineax.DiagonalLinearOperator(some_vector)\n"
507+
"\n"
508+
"diffrax.ControlTerm(vector_field)\n"
509+
"```\n"
510+
"Lineax is available at `https://github.com/patrick-kidger/lineax`.\n",
511+
stacklevel=3,
394512
)
395-
super().__init__(*args, **kwargs)
396513

397514
def prod(self, vf: _VF, control: _Control) -> Y:
398515
with jax.numpy_dtype_promotion("standard"):

docs/api/terms.md

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,25 @@ One of the advanced features of Diffrax is its *term* system. When we write down
44

55
$\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)$
66

7-
then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a *vector field* ($f$ or $g$) and a *control* ($\mathrm{d}t$ or $\mathrm{d}w(t)$). There is also an implicit assumption about how the vector field and control interact: $f$ and $\mathrm{d}t$ interact as a vector-scalar product. $g$ and $\mathrm{d}w(t)$ interact as a matrix-vector product. (This interaction is always linear.)
7+
then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a *vector field* ($f$ or $g$) and a *control* ($\mathrm{d}t$ or $\mathrm{d}w(t)$). In addition (often not represented in mathematical notation), there is also a choice of how the vector field and control interact: $f$ and $\mathrm{d}t$ interact as a vector-scalar product. $g$ and $\mathrm{d}w(t)$ interact as a matrix-vector product. (In general this interaction is always bilinear.)
88

99
"Terms" are thus the building blocks of differential equations.
1010

1111
!!! example
1212

1313
Consider the ODE $\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))$. Then this has vector field $f$, control $\mathrm{d}t$, and their interaction is a vector-scalar product. This can be described as a single [`diffrax.ODETerm`][].
1414

15-
If multiple terms affect the same evolving state, then they should be grouped into a single [`diffrax.MultiTerm`][].
15+
#### Adding multiple terms, such as SDEs
16+
17+
We can add multiple terms together by grouping them into a single [`diffrax.MultiTerm`][].
1618

1719
!!! example
1820

19-
An SDE would have its drift described by [`diffrax.ODETerm`][] and the diffusion described by a [`diffrax.ControlTerm`][]. As these affect the same evolving state variable, they should be passed to the solver as `MultiTerm(ODETerm(...), ControlTerm(...))`.
21+
The SDE above would have its drift described by [`diffrax.ODETerm`][] and the diffusion described by a [`diffrax.ControlTerm`][]. As these affect the same evolving state variable, they should be passed to the solver as `MultiTerm(ODETerm(...), ControlTerm(...))`.
22+
23+
#### Independent terms, such as Hamiltonian systems
2024

21-
If terms affect different pieces of the state, then they should be placed in some PyTree structure. (The exact structure will depend on what the solver accepts.)
25+
If terms affect different pieces of the state, then they should be placed in some PyTree structure.
2226

2327
!!! example
2428

@@ -28,7 +32,31 @@ If terms affect different pieces of the state, then they should be placed in som
2832

2933
These would be passed to the solver as the 2-tuple of `(ODETerm(...), ODETerm(...))`.
3034

31-
Each solver is capable of handling certain classes of problems, as described by their `solver.term_structure`.
35+
#### What each solver accepts
36+
37+
Each solver in Diffrax will specify what kinds of problems it can handle, as described by their `.term_structure` attribute. Not all solvers are able to handle all problems!
38+
39+
Some example term structures include:
40+
41+
1. `solver.term_structure = AbstractTerm`
42+
43+
In this case the solver can handle a simple ODE as descibed above: `ODETerm` is a subclass of `AbstractTerm`.
44+
45+
It can also handle SDEs: `MultiTerm(ODETerm(...), ControlTerm(...))` includes everything wrapped into a single term (the `MultiTerm`), and at that point this defines an interface the solver knows how to handle.
46+
47+
Most solvers in Diffrax have this term structure.
48+
49+
2. `solver.term_structure = MultiTerm[tuple[ODETerm, ControlTerm]]`
50+
51+
In this case the solver specifically handles just SDEs of the form `MultiTerm(ODETerm(...), ControlTerm(...))`; nothing else is compatible.
52+
53+
Some SDE-specific solvers have this term structure.
54+
55+
3. `solver.term_structure = (AbstractTerm, AbstractTerm)`
56+
57+
In this case the solver is used to solve ODEs like the Hamiltonian system described above: we have a PyTree of terms, each of which is treated individually.
58+
59+
---
3260

3361
??? abstract "`diffrax.AbstractTerm`"
3462

@@ -41,10 +69,12 @@ Each solver is capable of handling certain classes of problems, as described by
4169
- vf_prod
4270
- is_vf_expensive
4371

44-
---
72+
??? note "Defining your own term types"
73+
74+
For advanced users: you can create your own terms if appropriate. For example if your diffusion is matrix, itself computed as a matrix-matrix product, then you may wish to define a custom term and specify its [`diffrax.AbstractTerm.vf_prod`][] method. By overriding this method you could express the contraction of the vector field - control as a matrix-(matix-vector) product, which is more efficient than the default (matrix-matrix)-vector product.
4575

46-
!!! note
47-
You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised more efficient matrix-vector product algorithm in `prod`.
76+
77+
---
4878

4979
::: diffrax.ODETerm
5080
selection:

0 commit comments

Comments
 (0)