Commit c27f6c7

lockwo authored and patrick-kidger committed
Lineax Control Terms (#436)
* lineax draft
* remove from docs
* add tests everywhere weakly is
* review feedback
* move test out of misc now
* extra weak removal
* add deprecate
* simplify tree map
* remove warning test
* warning fix
1 parent a3688bb commit c27f6c7

9 files changed: +162 -24 lines changed


diffrax/_solver/srk.py

Lines changed: 1 addition & 2 deletions
@@ -203,8 +203,7 @@ class AbstractSRK(AbstractSolver[_SolverState]):
     r"""A general Stochastic Runge-Kutta method.

     This accepts `terms` of the form
-    `MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or
-    `MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`.
+    `MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))`.
     Depending on the solver, the Brownian motion might need to generate
     different types of Lévy areas, specified by the `minimal_levy_area` attribute.

diffrax/_term.py

Lines changed: 35 additions & 1 deletion
@@ -1,12 +1,14 @@
 import abc
 import operator
+import warnings
 from collections.abc import Callable
 from typing import cast, Generic, Optional, TypeVar, Union

 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.tree_util as jtu
+import lineax as lx
 import numpy as np
 from equinox.internal import ω
 from jaxtyping import ArrayLike, PyTree, PyTreeDef
@@ -288,7 +290,8 @@ def to_ode(self) -> ODETerm:
 - `vector_field`: A callable representing the vector field. This callable takes three
     arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is
     the evolving state of the system. `args` are any static arguments as passed to
-    [`diffrax.diffeqsolve`][].
+    [`diffrax.diffeqsolve`][]. This `vector_field` can be a function that returns a
+    JAX array, or returns any [lineax `AbstractLinearOperator`](https://docs.kidger.site/lineax/api/operators/#lineax.AbstractLinearOperator).
 - `control`: The control. Should either be (A) a [`diffrax.AbstractPath`][], in which
     case its `evaluate(t0, t1)` method will be used to give the increment of the control
     over a time interval `[t0, t1]`, or (B) a callable `(t0, t1) -> increment`, which
@@ -310,6 +313,26 @@ class ControlTerm(_AbstractControlTerm[_VF, _Control]):
     A common special case is when `y0` and `control` are vector-valued, and
     `vector_field` is matrix-valued.

+    To make a weakly diagonal control term, simply have your vector field
+    callable return a `lx.DiagonalLinearOperator`.
+
+    !!! info
+
+        Why "weakly" diagonal? Consider the matrix representation of the vector field,
+        as a square diagonal matrix. In general, the (i,i)-th element may depend
+        upon any of the values of `y`. It is only if the (i,i)-th element only depends
+        upon the i-th element of `y` that the vector field is said to be "diagonal",
+        without the "weak". (This stronger property is useful in some SDE solvers.)
+
+    !!! example
+
+        ```python
+        control = UnsafeBrownianPath(shape=(2,), key=...)
+        vector_field = lambda t, y, args: lx.DiagonalLinearOperator(jnp.ones_like(y))
+        diffusion_term = ControlTerm(vector_field, control)
+        diffeqsolve(diffusion_term, ...)
+        ```
+
     !!! example

         ```python
@@ -335,6 +358,8 @@ class ControlTerm(_AbstractControlTerm[_VF, _Control]):
     """

     def prod(self, vf: _VF, control: _Control) -> Y:
+        if isinstance(vf, lx.AbstractLinearOperator):
+            return vf.mv(control)
         return jtu.tree_map(_prod, vf, control)

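The `prod` change above dispatches on the type of the vector field: a lineax operator is applied through its own `mv` method, and anything else falls through to the existing dense path. A minimal NumPy sketch of that dispatch follows. `DiagonalOperator` is a hypothetical stand-in for `lx.DiagonalLinearOperator` (it is not the lineax class), and the `tensordot` fallback is an assumption about what the dense `_prod` helper does, since its body is not part of this diff.

```python
import numpy as np


class DiagonalOperator:
    """Hypothetical stand-in for `lx.DiagonalLinearOperator`; not the lineax class."""

    def __init__(self, diagonal):
        self.diagonal = np.asarray(diagonal)

    def mv(self, vector):
        # A diagonal matrix-vector product reduces to an elementwise product.
        return self.diagonal * np.asarray(vector)


def prod(vf, control):
    """Mirror the dispatch added to `ControlTerm.prod`, for a single array leaf."""
    if isinstance(vf, DiagonalOperator):
        # Operator path: delegate to the operator's own matrix-vector product.
        return vf.mv(control)
    # Dense path (assumed): contract the trailing axes of `vf` with `control`.
    return np.tensordot(vf, control, axes=np.ndim(control))


diag = np.array([1.0, 2.0, 3.0])
control = np.array([0.5, -1.0, 2.0])

dense_result = prod(np.diag(diag), control)  # full 3x3 matrix times vector
operator_result = prod(DiagonalOperator(diag), control)  # elementwise shortcut
```

Both paths are expected to give the same increment here; the operator path simply avoids materialising the off-diagonal zeros.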

@@ -360,6 +385,15 @@ class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):
         without the "weak". (This stronger property is useful in some SDE solvers.)
     """

+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "WeaklyDiagonalControlTerm is pending deprecation and may be removed "
+            "in future versions. Consider using the new alternative "
+            "ControlTerm(lx.DiagonalLinearOperator(...)).",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def prod(self, vf: _VF, control: _Control) -> Y:
         with jax.numpy_dtype_promotion("standard"):
             return jtu.tree_map(operator.mul, vf, control)
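
The `__init__` override above emits a `DeprecationWarning` at construction time. As a rough, standard-library-only illustration of that pattern, and of how a test might observe the warning, consider the sketch below. `OldTerm` and `construct_and_collect` are hypothetical names, and `stacklevel=2` is an addition that the commit itself does not use.

```python
import warnings


class OldTerm:
    """Hypothetical class standing in for a term that is being phased out."""

    def __init__(self, vector_field, control):
        warnings.warn(
            "OldTerm is pending deprecation; consider the newer alternative.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not to __init__
        )
        self.vector_field = vector_field
        self.control = control


def construct_and_collect():
    """Build an OldTerm while recording every warning raised along the way."""
    with warnings.catch_warnings(record=True) as caught:
        # DeprecationWarning is often filtered out by default, so force it on.
        warnings.simplefilter("always")
        term = OldTerm(lambda t, y, args: y, lambda t0, t1: t1 - t0)
    return term, caught


term, caught = construct_and_collect()
categories = [w.category for w in caught]
```

Because the warning fires in the constructor, existing code keeps working unchanged and only sees the message when it builds the deprecated term.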

docs/api/solvers/sde_solvers.md

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast

 The type of solver chosen determines how the `terms` argument of `diffeqsolve` should be laid out.

-Most solvers handle both ODEs and SDEs in the same way, and expect a single term. So for an ODE you would pass `terms=ODETerm(vector_field)`, and for an SDE you would pass `terms=MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or `terms=MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`. For example:
+Most solvers handle both ODEs and SDEs in the same way, and expect a single term. So for an ODE you would pass `terms=ODETerm(vector_field)`, and for an SDE you would pass `terms=MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))`. For example:

 ```python
 drift = lambda t, y, args: -y
@@ -18,7 +18,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast

 For any individual solver then this is documented below, and is also available programmatically under `<solver>.term_structure`.

-For advanced users, note that we typically accept any `AbstractTerm` for the diffusion, so it could be a custom one that implements more-efficient behaviour for the structure of your diffusion matrix. (Much like how [`diffrax.WeaklyDiagonalControlTerm`][] is more efficient than [`diffrax.ControlTerm`][] for diagonal diffusions.)
+For advanced users, note that we typically accept any `AbstractTerm` for the diffusion, so it could be a custom one that implements more-efficient behaviour for the structure of your diffusion matrix.

 ---

@@ -52,7 +52,7 @@ These solvers can be used to solve SDEs just as well as they can be used to solv

 !!! info "Term structure"

-    These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically.
+    These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (typically `SomeOtherTerm` will be a `ControlTerm`), representing the drift and diffusion specifically.


 ::: diffrax.EulerHeun

docs/api/terms.md

Lines changed: 1 addition & 7 deletions
@@ -44,7 +44,7 @@ Each solver is capable of handling certain classes of problems, as described by
 ---

 !!! note
-    You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised more efficient matrix-vector product algorithm in `prod`. For example this is what [`diffrax.WeaklyDiagonalControlTerm`][] does, as compared to just [`diffrax.ControlTerm`][].
+    You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised more efficient matrix-vector product algorithm in `prod`.

 ::: diffrax.ODETerm
     selection:
@@ -57,12 +57,6 @@ Each solver is capable of handling certain classes of problems, as described by
             - __init__
             - to_ode

-::: diffrax.WeaklyDiagonalControlTerm
-    selection:
-        members:
-            - __init__
-            - to_ode
-
 ::: diffrax.MultiTerm
     selection:
         members:

docs/usage/extending.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ The main points of extension are as follows:
 - **Custom controls** (e.g. **custom interpolation schemes** analogous to [`diffrax.CubicInterpolation`][]) should inherit from [`diffrax.AbstractPath`][].

 - **Custom terms** should inherit from [`diffrax.AbstractTerm`][].
-    - For example, if the vector field - control interaction is a matrix-vector product, but the matrix is known to have special structure, then you may wish to create a custom term that can calculate this interaction more efficiently than is given by a full matrix-vector product. For example this is done with [`diffrax.WeaklyDiagonalControlTerm`][] as compared to [`diffrax.ControlTerm`][].
+    - For example, if the vector field - control interaction is a matrix-vector product, but the matrix is known to have special structure, then you may wish to create a custom term that can calculate this interaction more efficiently than is given by a full matrix-vector product. Given the large suite of linear operators [lineax](https://docs.kidger.site/lineax/) implements (which are fully supported by [`diffrax.ControlTerm`][]), this is likely rarely necessary.

 In each case we recommend looking up existing solvers/etc. in Diffrax to understand how to implement them.

test/test_adjoint.py

Lines changed: 14 additions & 2 deletions
@@ -8,6 +8,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import lineax as lx
 import optax
 import pytest
 from jaxtyping import Array
@@ -331,7 +332,11 @@ def run(model):
     run(mlp)


-def test_sde_against(getkey):
+@pytest.mark.parametrize(
+    "diffusion_fn",
+    ["weak", "lineax"],
+)
+def test_sde_against(diffusion_fn, getkey):
     def f(t, y, args):
         k0, _ = args
         return -k0 * y
@@ -340,14 +345,21 @@ def g(t, y, args):
         _, k1 = args
         return k1 * y

+    def g_lx(t, y, args):
+        _, k1 = args
+        return lx.DiagonalLinearOperator(k1 * y)
+
     t0 = 0
     t1 = 1
     dt0 = 0.001
     tol = 1e-5
     shape = (2,)
     bm = diffrax.VirtualBrownianTree(t0, t1, tol, shape, key=getkey())
     drift = diffrax.ODETerm(f)
-    diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm)
+    if diffusion_fn == "weak":
+        diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm)
+    else:
+        diffusion = diffrax.ControlTerm(g_lx, bm)
     terms = diffrax.MultiTerm(drift, diffusion)
     solver = diffrax.Heun()

test/test_integrate.py

Lines changed: 35 additions & 0 deletions
@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import lineax as lx
 import pytest
 import scipy.stats
 from diffrax import ControlTerm, MultiTerm, ODETerm
@@ -644,6 +645,9 @@ class TestSolver(diffrax.AbstractSolver):
             "e": diffrax.MultiTerm[
                 tuple[diffrax.ODETerm, diffrax.AbstractTerm[Any, Float[Array, " 5"]]]
             ],
+            "f": diffrax.MultiTerm[
+                tuple[diffrax.ODETerm, diffrax.AbstractTerm[Any, Float[Array, " 5"]]]
+            ],
         }
         interpolation_cls = diffrax.LocalLinearInterpolation

@@ -676,13 +680,21 @@ def func(self, terms, t0, y0, args):
                 lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(5)
             ),
         ),
+        "f": diffrax.MultiTerm(
+            ode_term,
+            diffrax.ControlTerm(
+                lambda t, y, args: lx.DiagonalLinearOperator(-y),
+                lambda t0, t1: jnp.array(t1 - t0).repeat(5),
+            ),
+        ),
     }
     compatible_y0 = {
         "a": jnp.array(1.0),
         "b": jnp.array(2.0),
         "c": jnp.arange(3.0),
         "d": jnp.arange(4.0),
         "e": jnp.arange(5.0),
+        "f": jnp.arange(5.0),
     }
     diffrax.diffeqsolve(compatible_term, solver, 0.0, 1.0, 0.1, compatible_y0)

@@ -698,6 +710,13 @@ def func(self, terms, t0, y0, args):
                 lambda t0, t1: t1 - t0,  # wrong control shape
             ),
         ),
+        "f": diffrax.MultiTerm(
+            ode_term,
+            diffrax.ControlTerm(
+                lambda t, y, args: lx.DiagonalLinearOperator(-y),
+                lambda t0, t1: jnp.array(t1 - t0).repeat(5),
+            ),
+        ),
     }
     incompatible_term2 = {
         "a": ode_term,
@@ -710,6 +729,13 @@ def func(self, terms, t0, y0, args):
                 lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3)
             ),
         ),
+        "f": diffrax.MultiTerm(
+            ode_term,
+            diffrax.ControlTerm(
+                lambda t, y, args: lx.DiagonalLinearOperator(-y),
+                lambda t0, t1: jnp.array(t1 - t0).repeat(5),
+            ),
+        ),
     }
     incompatible_term3 = {
         "a": ode_term,
@@ -720,6 +746,13 @@ def func(self, terms, t0, y0, args):
         "e": diffrax.WeaklyDiagonalControlTerm(
             lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3)
         ),
+        "f": diffrax.MultiTerm(
+            ode_term,
+            diffrax.ControlTerm(
+                lambda t, y, args: lx.DiagonalLinearOperator(-y),
+                lambda t0, t1: jnp.array(t1 - t0).repeat(5),
+            ),
+        ),
     }

     incompatible_y0_1 = {
@@ -728,13 +761,15 @@ def func(self, terms, t0, y0, args):
         "c": jnp.arange(4.0),  # of length 4, not 3
         "d": jnp.arange(4.0),
         "e": jnp.arange(5.0),
+        "f": jnp.arange(5.0),
     }
     incompatible_y0_2 = {
         "a": jnp.array(1.0),
         "b": jnp.array(2.0),
         "c": jnp.arange(3.0),
         # Missing "d" piece
         "e": jnp.arange(5.0),
+        "f": jnp.arange(5.0),
     }
     incompatible_y0_3 = jnp.array(4.0)  # Completely the wrong structure!
     for term in (

test/test_sde.py

Lines changed: 50 additions & 4 deletions
@@ -5,6 +5,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import lineax as lx
 import pytest
 from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm

@@ -270,18 +271,63 @@ def _drift(t, y, args):
     assert solution.ys.shape == (1, 3)


+def _lineax_weakly_diagonal_noise_helper(solver, dtype):
+    w_shape = (3,)
+    args = (0.5, 1.2)
+
+    def _diffusion(t, y, args):
+        a, b = args
+        return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype))
+
+    def _drift(t, y, args):
+        a, b = args
+        return -a * y
+
+    y0 = jnp.ones(w_shape, dtype)
+
+    bm = diffrax.VirtualBrownianTree(
+        0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea
+    )
+
+    terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm))
+    saveat = diffrax.SaveAt(t1=True)
+    solution = diffrax.diffeqsolve(
+        terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
+    )
+    assert solution.ys is not None
+    assert solution.ys.shape == (1, 3)
+
+
 @pytest.mark.parametrize("solver_ctr", _solvers())
 @pytest.mark.parametrize(
     "dtype",
     (jnp.float64, jnp.complex128),
 )
-def test_weakly_diagonal_noise(solver_ctr, dtype):
-    _weakly_diagonal_noise_helper(solver_ctr(), dtype)
+@pytest.mark.parametrize(
+    "weak_type",
+    ("old", "lineax"),
+)
+def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type):
+    if weak_type == "old":
+        _weakly_diagonal_noise_helper(solver_ctr(), dtype)
+    elif weak_type == "lineax":
+        _lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype)
+    else:
+        raise ValueError("Invalid weak_type")


 @pytest.mark.parametrize(
     "dtype",
     (jnp.float64, jnp.complex128),
 )
-def test_halfsolver_term_compatible(dtype):
-    _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
+@pytest.mark.parametrize(
+    "weak_type",
+    ("old", "lineax"),
+)
+def test_halfsolver_term_compatible(dtype, weak_type):
+    if weak_type == "old":
+        _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
+    elif weak_type == "lineax":
+        _lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
+    else:
+        raise ValueError("Invalid weak_type")
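
The `ControlTerm` docstring in this commit distinguishes a "weakly diagonal" vector field, whose diagonal entries may depend on any component of `y`, from a "diagonal" one, where the i-th entry depends only on `y[i]`. One way to picture the difference is through the Jacobian of the diagonal entries with respect to `y`. The NumPy sketch below is purely illustrative: all function names are hypothetical, and the finite-difference Jacobian is a convenience for the example rather than anything diffrax does.

```python
import numpy as np


def weakly_diagonal(y):
    # Every diagonal entry depends on the whole state through sum(y).
    return np.full_like(y, np.sum(y))


def diagonal(y):
    # The i-th diagonal entry depends only on y[i].
    return 2.0 * y


def jacobian(fn, y, eps=1e-6):
    """Central finite-difference Jacobian d fn_i / d y_j (illustrative only)."""
    n = y.size
    jac = np.zeros((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (fn(y + step) - fn(y - step)) / (2 * eps)
    return jac


def is_diagonal_matrix(m, tol=1e-6):
    # True when all off-diagonal entries vanish up to the tolerance.
    return bool(np.allclose(m - np.diag(np.diag(m)), 0.0, atol=tol))


y = np.array([1.0, 2.0, 3.0])
weak_has_diagonal_jacobian = is_diagonal_matrix(jacobian(weakly_diagonal, y))
strong_has_diagonal_jacobian = is_diagonal_matrix(jacobian(diagonal, y))
```

Under this reading, the stronger "diagonal" property corresponds to a diagonal Jacobian, while a merely weakly diagonal field can have off-diagonal Jacobian entries.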
