Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@
from ._term import (
AbstractTerm as AbstractTerm,
ControlTerm as ControlTerm,
KLState as KLState,
make_kl_terms as make_kl_terms,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm,
Expand Down
164 changes: 164 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,170 @@ def _to_vjp(_y, _diff_args, _diff_term):
return dy, da_y, da_diff_args, da_diff_term


class KLState(eqx.Module, strict=True):
    """
    Augmented solver state: the SDE state paired with a KL divergence term.

    Instances of this class are used both as the state being evolved and as
    the vector-field values produced by the KL terms in this module.
    """

    # State of the SDE being integrated (or the matching vector-field value).
    posterior: Y
    # KL component; `make_kl_terms` initialises it to a scalar zero.
    kl_metric: Array


def _compute_kl_integral(
    drift_term1: ODETerm,
    drift_term2: ODETerm,
    diffusion_term: ControlTerm,
    t0: RealScalarLike,
    y0: Y,
    args: Args,
    linear_solver: lx.AbstractLinearSolver,
) -> KLState:
    """
    Evaluate the first drift and the KL integrand at `(t0, y0)`.

    The integrand is `0.5 * ||g^{-1} (f1 - f2)||^2`, where `f1` and `f2` are the
    two drifts and `g` is the diffusion, all evaluated at `(t0, y0, args)`.

    **Arguments**

    - `drift_term1`: the drift whose value is returned as the `posterior` entry.
    - `drift_term2`: the drift subtracted from `drift_term1`.
    - `diffusion_term`: the diffusion shared by both SDEs.
    - `t0`: the time at which to evaluate.
    - `y0`: the (non-augmented) state at which to evaluate.
    - `args`: extra arguments passed to every vector field.
    - `linear_solver`: the solver used to apply the inverse of the diffusion.

    **Returns**

    A `KLState` holding the value of `drift_term1` and the scalar KL integrand.
    """
    posterior_drift = drift_term1.vf(t0, y0, args)
    prior_drift = drift_term2.vf(t0, y0, args)
    drift_gap = jtu.tree_map(operator.sub, posterior_drift, prior_drift)

    # Only one diffusion is evaluated: both SDEs are assumed to share it.
    diffusion_op = diffusion_term.vf(t0, y0, args)
    if not isinstance(diffusion_op, lx.AbstractLinearOperator):
        # Anything that is not already a linear operator is wrapped as a matrix.
        diffusion_op = lx.MatrixLinearOperator(diffusion_op)

    # Solve `g x = f1 - f2` rather than forming the inverse of `g` explicitly.
    solution = lx.linear_solve(diffusion_op, drift_gap, solver=linear_solver)

    # Half the squared norm of every leaf, then summed across leaves.
    half_sq_norms = jtu.tree_map(
        lambda leaf: 0.5 * jnp.sum(leaf**2), solution.value
    )
    return KLState(posterior_drift, jtu.tree_reduce(operator.add, half_sq_norms))


class _KLDrift(AbstractTerm):
    """
    Drift term of the KL-augmented system.

    Its vector field is produced by `_compute_kl_integral`: the value of
    `drift1` for the `posterior` entry and the KL integrand for the
    `kl_metric` entry. The control is the elapsed time.
    """

    # NOTE: field order defines the positional constructor; do not reorder.
    drift1: ODETerm
    drift2: ODETerm
    diffusion: ControlTerm
    linear_solver: lx.AbstractLinearSolver

    def vf(self, t: RealScalarLike, y: KLState, args: Args) -> KLState:
        # The drifts and diffusion act on the SDE state only, so the KL entry of
        # the augmented state is not passed on.
        return _compute_kl_integral(
            self.drift1,
            self.drift2,
            self.diffusion,
            t,
            y.posterior,
            args,
            self.linear_solver,
        )

    def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> Control:
        # Time increment; any keyword arguments are accepted and ignored.
        return t1 - t0

    def prod(self, vf: VF, control: RealScalarLike) -> Y:
        # Scale every leaf of the vector-field value by the time increment.
        return jtu.tree_map(lambda leaf: control * leaf, vf)


class _KLControlTerm(AbstractTerm):
    """
    Diffusion term of the KL-augmented system.

    Wraps a `ControlTerm` so that it acts on the `posterior` entry of a
    `KLState`. Every value it produces carries a scalar zero in the
    `kl_metric` entry, so this term contributes nothing to the KL component.
    """

    control_term: ControlTerm

    def vf(self, t: RealScalarLike, y: KLState, args: Args) -> KLState:
        # Evaluate the wrapped diffusion on the SDE state only.
        post_vf = self.control_term.vf(t, y.posterior, args)
        return KLState(post_vf, jnp.array(0.0))

    def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> Control:
        # Forward the keyword arguments: they are options the caller addresses
        # to the underlying control, and dropping them here would silently
        # change what the wrapped term returns.
        return self.control_term.contr(t0, t1, **kwargs)

    def vf_prod(
        self, t: RealScalarLike, y: KLState, args: Args, control: Control
    ) -> KLState:
        # Delegate so that any specialised `vf_prod` of the wrapped term is used.
        prod_post = self.control_term.vf_prod(t, y.posterior, args, control)
        return KLState(prod_post, jnp.array(0.0))

    def prod(self, vf: KLState, control: Control) -> KLState:
        # Only the `posterior` entry interacts with the control.
        vf_post = self.control_term.prod(vf.posterior, control)
        return KLState(vf_post, jnp.array(0.0))


def make_kl_terms(
    posterior_sde: MultiTerm[tuple[ODETerm, ControlTerm]],
    prior_sde: MultiTerm[tuple[ODETerm, ControlTerm]],
    y0: Y,
    linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None),
) -> tuple[MultiTerm[tuple[_KLDrift, _KLControlTerm]], KLState]:
    r"""
    This generates the terms and initial state for estimating the KL divergence
    between two SDEs with the same diffusion. Specifically, given SDEs of the form

    $$
    \mathrm{d}y(t) = f_\theta (t, y(t)) dt + g_\phi (t, y(t)) dW(t) \qquad y(ts[0]) = y_0
    $$

    $$
    \mathrm{d}z(t) = h_\psi (t, z(t)) dt + g_\phi (t, z(t)) dW(t) \qquad z(ts[0]) = z_0
    $$

    compute:

    $$
    \int_{ts[i-1]}^{ts[i]} \frac{1}{2} \left\| g_\phi (t, y(t))^{-1} (f_\theta (t, y(t)) - h_\psi (t, y(t))) \right\|_2^2 dt
    $$

    for every time interval. This is useful for KL based latent SDEs. The output
    of the solution.ys will be a KLState containing the posterior SDE integration and the
    KL integrations over time. The `kl_metric` entry of the returned initial state is
    zero, so the value carried through the solve is the running total of these
    integrals. Note that this method requires inverting the diffusion
    matrix and as such, unless the diffusion is diagonal, the inverse can be extremely
    costly for higher dimensions.

    Each SDE must be a `MultiTerm` whose first term is the drift (`f` or `h`) and
    whose second term is the diffusion `g`. Note that the diffusions are
    not checked and are assumed to be the same: only the diffusion of
    `posterior_sde` is used.

    ??? cite "References"

        See section 5 of:

        ```bibtex
        @inproceedings{li2020scalable,
            title={Scalable gradients for stochastic differential equations},
            author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky TQ and Duvenaud, David},
            booktitle={International Conference on Artificial Intelligence and Statistics},
            pages={3870--3882},
            year={2020},
            organization={PMLR}
        }
        ```

        Or section 4.3.2 of:

        ```bibtex
        @article{kidger2022neural,
            title={On neural differential equations},
            author={Kidger, Patrick},
            journal={arXiv preprint arXiv:2202.02435},
            year={2022}
        }
        ```

    **Arguments**

    - `posterior_sde`: the posterior SDE to be integrated, this is the SDE which
        will have its integration tracked and logged in the `KLState`
    - `prior_sde`: the prior SDE from which we are estimating the KL divergence,
        this will not be fully integrated or logged.
    - `y0`: the initial state
    - `linear_solver`: the method for computing $g^{-1}(f - h)$.

    **Returns**

    A tuple containing the new terms to be fed into any SDE solver,
    and the `KLState` representing the initial starting point.

    """  # noqa: E501
    post_drift = posterior_sde.terms[0]
    prior_drift = prior_sde.terms[0]
    # The diffusion of the prior is never read; the posterior's is used for both.
    diffusion_term = posterior_sde.terms[1]
    terms = MultiTerm(
        _KLDrift(post_drift, prior_drift, diffusion_term, linear_solver),
        _KLControlTerm(diffusion_term),
    )
    # The KL component starts from zero.
    state = KLState(y0, jnp.array(0.0))
    return terms, state


# The Underdamped Langevin SDE trajectory consists of two components: the position
# `x` and the velocity `v`. Both of these have the same shape.
# So, by UnderdampedLangevinX we denote the shape of the x component, and by
Expand Down
3 changes: 2 additions & 1 deletion docs/api/terms.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Some example term structures include:
members:
- __init__

::: diffrax.make_kl_terms

---

Expand Down Expand Up @@ -125,4 +126,4 @@ where `bm` is an [`diffrax.AbstractBrownianPath`][] and the same values of `gamm
::: diffrax.UnderdampedLangevinDiffusionTerm
options:
members:
- __init__
- __init__
Loading