diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index e307c195..9ec5b025 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -1,6 +1,9 @@ # SDE solvers -See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochastic-differential-equations). +See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochastic-differential-equations) +and [Advanced SDE example](../../examples/sde_example.ipynb) which gives a walkthrough of how to simulate SDEs +and how to perform optimisation with respect to SDE parameters. +For a table of all SDE solvers and their properties see [SDE solver table](../../devdocs/SDE_solver_table.md). !!! info "Term structure" diff --git a/docs/devdocs/SDE_solver_table.md b/docs/devdocs/SDE_solver_table.md new file mode 100644 index 00000000..1c12def2 --- /dev/null +++ b/docs/devdocs/SDE_solver_table.md @@ -0,0 +1,44 @@ +# SDE solver table + +For an explanation of the terms in the table, see [how to choose a solver](../usage/how-to-choose-a-solver.md#stochastic-differential-equations). + +``` ++----------------+-------+------------+------------------------------------+-------------------+----------------+------------------------------------------+ +| | SDE | Lévy | Strong/weak order (per noise type) | VF evaluations | Embedded error | Recommended for | +| | type | area +----------+--------------+----------+-------+-----------+ estimation | (and other notes) | +| | | | General | Commutative | Additive | Drift | Diffusion | | | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Euler | Itô | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 1 | No | Itô SDEs when a cheap solver is needed. 
| ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Heun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Standard solver for Stratonovich SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| EulerHeun | Strat | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 2 | No | Stratonovich SDEs with expensive drift. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ItoMilstein | Itô | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | No | Better than Euler for Itô SDEs, but | +| | | | | | | | | | computes the derivative of diffusion VF. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Stratonovich | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | No | For commutative Stratonovich SDEs when | +| Milstein | | | | | | | | | space-time Lévy area is not available. | +| | | | | | | | | | Computes derivative of diffusion VF. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ReversibleHeun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | When a reversible solver is needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Midpoint | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. 
| ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| Ralston | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| ShARK | Strat | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Additive noise SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SRA1 | Strat | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Only slightly worse than ShARK. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SEA | Strat | space-time | / | / | 1.0/1.0 | 1 | 1 | No | Cheap solver for additive noise SDEs. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SPaRK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 3 | 3 | Yes | General SDEs when embedded error | +| | | | | | | | | | estimation is needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| GeneralShARK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 2 | 3 | No | General SDEs when embedded error | +| | | | | | | | | | estimation is not needed. | ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +| SlowRK | Strat | space-time | 0.5/1.0 | 1.5/2.0 | 1.5/2.0 | 2 | 5 | No | Commutative noise SDEs. 
| ++----------------+-------+------------+----------+--------------+----------+-------+-----------+----------------+------------------------------------------+ +``` \ No newline at end of file diff --git a/docs/examples/sde_example.ipynb b/docs/examples/sde_example.ipynb new file mode 100644 index 00000000..f9c7fffc --- /dev/null +++ b/docs/examples/sde_example.ipynb @@ -0,0 +1,582 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-10-09T01:03:28.847571Z", + "start_time": "2025-10-09T01:03:28.174400Z" + } + }, + "source": [ + "%env JAX_PLATFORM_NAME=cuda\n", + "\n", + "from warnings import simplefilter\n", + "\n", + "\n", + "simplefilter(\"ignore\", category=FutureWarning)\n", + "\n", + "from functools import partial\n", + "\n", + "import diffrax\n", + "import equinox as eqx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "import optax\n", + "from jaxtyping import Array\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "jnp.set_printoptions(precision=4, suppress=True)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: JAX_PLATFORM_NAME=cuda\n" + ] + } + ], + "execution_count": 1 + }, + { + "cell_type": "markdown", + "id": "86d4e8b062a81d7e", + "metadata": {}, + "source": [ + "# Advanced SDE example\n", + "\n", + "We will be simulating a Stratonovich SDE of the form:\n", + "\n", + "$$\n", + " dy(t) = f(y(t), t) dt + g(y(t), t) \\circ dw(t), \n", + "$$\n", + "\n", + "where $t \\in [0, T]$, $y(t) \\in \\mathbb{R}^e$, and $w$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ is the diffusion matrix field. 
The Stratonovich integral is denoted by $\\circ$.\n", + "\n", + "Our SDE will have the following drift and diffusion terms:\n", + "\n", + "\\begin{align*}\n", + " f(y(t), t) &= \\alpha - \\beta y(t), \\\\\n", + " g(y(t), t) &= \\operatorname{diag}(\\gamma) \\begin{bmatrix} \\Vert y(t) \\Vert_2 & 0 \\\\ 0 & 3 y_1(t) \\\\ 0 & 20t \\end{bmatrix},\n", + "\\end{align*}\n", + "\n", + "where $\\alpha, \\gamma \\in \\mathbb{R}^3$ and $\\beta \\in \\mathbb{R}_{\\geq 0}$ are some parameters.\n", + "\n", + "Let's write the SDE in the form that Diffrax expects:" + ] + }, + { + "cell_type": "code", + "id": "ba23e9cc0370fbac", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:29.331509Z", + "start_time": "2025-10-09T01:03:28.853859Z" + } + }, + "source": [ + "# Drift VF (e = 3)\n", + "def f(t, y, args):\n", + " alpha, beta, gamma = args\n", + " beta = jnp.abs(beta)\n", + " assert alpha.shape == (3,)\n", + " return jnp.array(alpha - beta * y, dtype=y.dtype)\n", + "\n", + "\n", + "# Diffusion matrix field (e = 3, d = 2)\n", + "def g(t, y, args):\n", + " alpha, beta, gamma = args\n", + " assert gamma.shape == y.shape == (3,)\n", + " gamma = jnp.reshape(gamma, (3, 1))\n", + " out = gamma * jnp.array(\n", + " [[jnp.sqrt(jnp.sum(y**2)), 0.0], [0.0, 3 * y[0]], [0.0, 20 * t]], dtype=y.dtype\n", + " )\n", + " return out\n", + "\n", + "\n", + "# Initial condition\n", + "y0 = jnp.array([1.0, 1.0, 1.0])\n", + "\n", + "# Args\n", + "alpha = 0.5 * jnp.ones((3,))\n", + "beta = 1.0\n", + "gamma = jnp.ones((3,))\n", + "args = (alpha, beta, gamma)\n", + "\n", + "# Time domain\n", + "t0 = 0.0\n", + "t1 = 2.0\n", + "dt0 = 2**-9" + ], + "outputs": [], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "id": "ef2ff90865907b7d", + "metadata": {}, + "source": [ + "## Brownian motion and its Levy area\n", + "\n", + "Different solvers require different information about the Brownian motion. For example, the `SPaRK` solver requires access to the space-time Levy area of the Brownian motion. 
The required Levy area for each solver is documented in the [SDE solver table](../devdocs/SDE_solver_table.md), or can be checked via `solver.minimal_levy_area`.\n", + " \n", + "We will use the `VirtualBrownianTree` class to generate the Brownian motion and its Levy area." + ] + }, + { + "cell_type": "code", + "id": "4110735158215acc", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:29.483519Z", + "start_time": "2025-10-09T01:03:29.337297Z" + } + }, + "source": [ + "# check minimal levy area\n", + "solver = diffrax.SPaRK()\n", + "print(f\"Minimal levy area for SPaRK: {solver.minimal_levy_area}.\")\n", + "\n", + "# Brownian motion\n", + "key = jr.key(0)\n", + "bm_tol = 2**-13\n", + "bm_shape = (2,)\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=diffrax.SpaceTimeLevyArea\n", + ")\n", + "\n", + "# Defining the terms of the SDE\n", + "ode_term = diffrax.ODETerm(f)\n", + "diffusion_term = diffrax.ControlTerm(g, bm) # Note that the BM is baked into the term\n", + "terms = diffrax.MultiTerm(ode_term, diffusion_term)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Minimal levy area for SPaRK: .\n" + ] + } + ], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "id": "e71db03c5257bd46", + "metadata": {}, + "source": [ + "### Using `diffrax.diffeqsolve` to solve the SDE\n", + "\n", + "We will first use constant steps of size $h = 2^{-9}$ to solve the SDE. The step size must satisfy $h > \\mathtt{bm\\_tol}$, where $\\mathtt{bm\\_tol}$ is the tolerance of the Brownian motion, because the output distribution of the `VirtualBrownianTree` is only accurate when the times at which we sample it are at least $\\mathtt{bm\\_tol}$ apart. For more details see the [Single-seed Brownian Motion paper](https://arxiv.org/abs/2405.06464).\n", + "\n", + "We will use the SPaRK solver to solve the SDE. 
SPaRK is a stochastic Runge-Kutta method that requires access to space-time Levy area." + ] + }, + { + "cell_type": "code", + "id": "8a969e1b9bd9f09", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:35.618860Z", + "start_time": "2025-10-09T01:03:29.493121Z" + } + }, + "source": [ + "sol = diffrax.diffeqsolve(\n", + " terms, diffrax.SPaRK(), t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(steps=True)\n", + ")\n", + "\n", + "# Plotting the solution on ax1 and the BM on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol.ts, sol.ys[:, 0], label=\"y_1\")\n", + "ax1.plot(sol.ts, sol.ys[:, 1], label=\"y_2\")\n", + "ax1.plot(sol.ts, sol.ys[:, 2], label=\"y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "bm_vals = jax.vmap(lambda t: bm.evaluate(t0, t))(jnp.clip(sol.ts, t0, t1))\n", + "ax2.plot(sol.ts, bm_vals[:, 0], label=\"BM_1\")\n", + "ax2.plot(sol.ts, bm_vals[:, 1], label=\"BM_2\")\n", + "ax2.set_title(\"Brownian motion\")\n", + "\n", + "plt.show()" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 4 + }, + { + "cell_type": "markdown", + "id": "fd3251c814306cd", + "metadata": {}, + "source": [ + "## Using adaptive time-stepping via the PID-controller\n", + "\n", + "In order to use adaptive time stepping, the solver must produce an estimate of its error on each step. This is then used by the PID controller to adjust the step size.\n", + "To perform this error estimation the `SPaRK` solver uses an embedded method. For solvers like `GeneralShARK`, which do not have an embedded method, we'd instead need to use `HalfSolver(GeneralShARK())` as the solver in order to estimate the error." + ] + }, + { + "cell_type": "code", + "id": "42ca5c5520079b5f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:43.605326Z", + "start_time": "2025-10-09T01:03:35.705678Z" + } + }, + "source": [ + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=0.005,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-12,\n", + " dtmax=0.25,\n", + ")\n", + "\n", + "solver = diffrax.SPaRK()\n", + "# solver = diffrax.HalfSolver(diffrax.GeneralShARK())\n", + "\n", + "sol_pid_spark = diffrax.diffeqsolve(\n", + " terms,\n", + " solver,\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_spark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_spark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 0], label=\"y_1\")\n", + 
"ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 1], label=\"y_2\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 2], label=\"y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_spark.ts))\n", + "ts = sol_pid_spark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accepted steps: 2968, Rejected steps: 1637, total steps: 4605\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "id": "344b5f07d5120128", + "metadata": {}, + "source": [ + "## Solving an SDE for a batch of Brownian motions\n", + "\n", + "When doing Monte Carlo simulations, we often need to solve the same SDE for multiple Brownian motions. We can do this via `jax.vmap`." + ] + }, + { + "cell_type": "code", + "id": "ffe3ced461ebb823", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:43.764668Z", + "start_time": "2025-10-09T01:03:43.657163Z" + } + }, + "source": [ + "def get_terms(bm):\n", + " return diffrax.MultiTerm(ode_term, diffrax.ControlTerm(g, bm))\n", + "\n", + "\n", + "# Fix which times we step to (this is equivalent to a constant step size)\n", + "# We do this because the combination of using dt0 and SaveAt(steps=True) pads the\n", + "# output with inf values up to max_steps.\n", + "# Instead we specify exactly which times we want to save at, so Diffrax allocates\n", + "# the correct amount of memory at the outset.\n", + "num_steps = 2**8\n", + "step_times = jnp.linspace(t0, t1, num_steps + 1, endpoint=True)\n", + "constant_controller = diffrax.StepTo(ts=step_times)\n", + "saveat = diffrax.SaveAt(ts=step_times)\n", + "\n", + "\n", + "# We will vmap over keys\n", + "@eqx.filter_jit\n", + "@partial(jax.vmap, in_axes=(0, None, None))\n", + "def batch_sde_solve(key, saveat, args):\n", + " bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=diffrax.SpaceTimeLevyArea\n", + " )\n", + " terms = get_terms(bm)\n", + " return diffrax.diffeqsolve(\n", + " terms,\n", + " diffrax.SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + " )\n", + "\n", + "\n", + "# Split the keys and compute the batched solutions\n", + "num_samples = 100\n", + "keys = jr.split(jr.PRNGKey(0), num_samples)" + ], + 
"outputs": [], + "execution_count": 6 + }, + { + "cell_type": "code", + "id": "3c1206025f30100d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:03:46.771758Z", + "start_time": "2025-10-09T01:03:43.769093Z" + } + }, + "source": [ + "batch_sols = batch_sde_solve(keys, saveat, args)\n", + "print(\n", + " f\"Shape of batch_sols: \"\n", + " f\"{batch_sols.ys.shape} == {num_samples} x {num_steps + 1} x (dim of y)\"\n", + ")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of batch_sols: (100, 257, 3) == 100 x 257 x (dim of y)\n" + ] + } + ], + "execution_count": 7 + }, + { + "cell_type": "markdown", + "id": "71dda42d79d4c553", + "metadata": {}, + "source": [ + "## Optimizing wrt. SDE parameters\n", + "We will optimize the SDE parameters with the aim of achieving a mean of 0 and variance 4 at time `t1`." + ] + }, + { + "cell_type": "code", + "id": "d278fc2d438ffc82", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:07:34.736540Z", + "start_time": "2025-10-09T01:03:46.832280Z" + } + }, + "source": [ + "saveat_t1 = diffrax.SaveAt(t1=True)\n", + "batch_ys = batch_sde_solve(keys, saveat_t1, args).ys\n", + "print(batch_ys.shape)\n", + "ys_t1 = batch_ys[:, 0]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")\n", + "\n", + "\n", + "# We will optimize for achieving a mean of 0\n", + "def loss(args: tuple[Array, Array, Array]):\n", + " _batch_sols = batch_sde_solve(keys, saveat_t1, args)\n", + " batch_ys = _batch_sols.ys\n", + " assert batch_ys.shape == (num_samples, 1, 3)\n", + " mean = jnp.mean(batch_ys, axis=(0, 1))\n", + " std = jnp.sqrt(jnp.mean(batch_ys**2, axis=(0, 1)) - mean**2)\n", + " target_mean = jnp.array([0.0, 1.0, 0.0])\n", + " target_stds = 2 * jnp.ones((3,))\n", + " loss = jnp.sqrt(\n", + " jnp.sum((mean - target_mean) ** 2) + jnp.sum((std - target_stds) ** 2)\n", + " )\n", + " 
return loss\n", + "\n", + "\n", + "# Define the parameters to optimize\n", + "alpha_opt = 0.5 * jnp.ones((3,))\n", + "beta_opt = jnp.array(1.0)\n", + "gamma_opt = jnp.ones((3,))\n", + "args_opt = (alpha_opt, beta_opt, gamma_opt)\n", + "\n", + "# Define the optimizer\n", + "num_steps = 191\n", + "schedule = optax.cosine_decay_schedule(3e-1, num_steps, 1e-2)\n", + "opt = optax.chain(\n", + " optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-8),\n", + " optax.scale_by_schedule(schedule),\n", + " optax.scale(-1),\n", + ")\n", + "# opt = optax.adam(2e-1)\n", + "opt_state = opt.init(args_opt)\n", + "\n", + "\n", + "@jax.jit\n", + "def step(i, opt_state, args):\n", + " loss_val, grad = jax.value_and_grad(loss)(args)\n", + " updates, opt_state = opt.update(grad, opt_state)\n", + "\n", + " # One way to apply updates\n", + " # args = optax.apply_updates(args, updates)\n", + "\n", + " # Another way to apply updates\n", + " args = jax.tree_util.tree_map(lambda x, u: x + u, args, updates)\n", + "\n", + " return opt_state, args, loss_val\n", + "\n", + "\n", + "for i in range(num_steps):\n", + " opt_state, args_opt, loss_val = step(i, opt_state, args_opt)\n", + " alpha_opt, beta_opt, gamma_opt = args_opt\n", + " if i % 10 == 0:\n", + " print(f\"Step {i}, loss: {loss_val}\")\n", + "\n", + "print(\n", + " f\"Optimal parameters:\\n\"\n", + " f\"alpha={alpha_opt},\"\n", + " f\" beta={beta_opt}, gamma={gamma_opt}\"\n", + ")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(100, 1, 3)\n", + "Stats at t=t1: mean=[-1.0154 1.5479 2.0878], var=[329.4802 711.1078 424.3449]\n", + "Step 0, loss: 34.94212183883967\n", + "Step 10, loss: 4.874442625767792\n", + "Step 20, loss: 2.521634210842532\n", + "Step 30, loss: 1.4702096092338783\n", + "Step 40, loss: 0.7936488640119762\n", + "Step 50, loss: 0.20701309373712398\n", + "Step 60, loss: 0.43545896965573144\n", + "Step 70, loss: 0.48191871779789575\n", + "Step 80, loss: 0.14351136791805125\n", + "Step 90, loss: 
0.42323194856385005\n", + "Step 100, loss: 0.7814543571174357\n", + "Step 110, loss: 0.5590729910392899\n", + "Step 120, loss: 0.09288914937239617\n", + "Step 130, loss: 0.1462945784163213\n", + "Step 140, loss: 0.29703455403048784\n", + "Step 150, loss: 0.06270444996116936\n", + "Step 160, loss: 0.01298645327270607\n", + "Step 170, loss: 0.08775177455266986\n", + "Step 180, loss: 0.016462953232162895\n", + "Step 190, loss: 0.018917675036979466\n", + "Optimal parameters:\n", + "alpha=[-0.1822 3.5395 -0.0834], beta=3.645413009852767, gamma=[-1.6817 -0.8223 0.149 ]\n" + ] + } + ], + "execution_count": 8 + }, + { + "cell_type": "code", + "id": "834651877787c7e6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-09T01:07:38.550105Z", + "start_time": "2025-10-09T01:07:34.801808Z" + } + }, + "source": [ + "batch_ys_opt = batch_sde_solve(keys, saveat_t1, args_opt).ys\n", + "ys_t1 = batch_ys_opt[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stats at t=t1: mean=[0.0001 1.0005 0.0002], var=[4.0193 4.0269 4.0155]\n" + ] + } + ], + "execution_count": 9 + }, + { + "cell_type": "markdown", + "id": "d103fe1695cdd847", + "metadata": {}, + "source": "With the magic of JAX and Diffrax we were able to differentiate through the SDE solver and optimize the parameters of the SDE to achieve the desired mean and variance at time `t1`." 
+ } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index 2dce12d1..54d890a9 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -42,54 +42,51 @@ For "split stiffness" problems, with one term that is stiff and another term tha ## Stochastic differential equations -SDE solvers are relatively specialised depending on the type of problem. Each solver will converge to either the Itô solution or the Stratonovich solution. In addition some solvers require "commutative noise". +SDE solvers are relatively specialised depending on the type of problem. Each solver will converge to either the Itô solution or the Stratonovich solution of the SDE. The Itô and Stratonovich solutions coincide iff the SDE has additive noise (as defined below). In addition some solvers require the SDE to have "commutative noise" or "additive noise". All of these terms are defined below. -!!! info "Commutative noise" +### General (noncommutative) noise +This includes any SDE of the form $dy(t) = f(y(t), t) dt + g(y(t), t) dw(t),$ where $t \in [0, T]$, $y(t) \in \mathbb{R}^e$, and $w$ is a standard Brownian motion on $\mathbb{R}^d$. We refer to $f: \mathbb{R}^e \times [0, T] \to \mathbb{R}^e$ as the drift vector field (VF) and $g: \mathbb{R}^e \times [0, T] \to \mathbb{R}^{e \times d}$ is the diffusion matrix field with columns $g_i$ for $i = 1, \ldots, d$. 
- Consider the SDE - $\mathrm{d}y(t) = μ(t, y(t))\mathrm{d}t + σ(t, y(t))\mathrm{d}w(t)$ +### Commutative noise +The diffusion matrix $g$ is said to satisfy the commutativity condition if - then the diffusion matrix $σ$ is said to satisfy the commutativity condition if +$\sum_{i=1}^e g_{i, j} \frac{\partial g_{k, l}}{\partial y_i} = \sum_{i=1}^e g_{i, l} \frac{\partial g_{k, j}}{\partial y_i}$ - $\sum_{i=1}^d σ_{i, j} \frac{\partial σ_{k, l}}{\partial y_i} = \sum_{i=1}^d σ_{i, l} \frac{\partial σ_{k, j}}{\partial y_i}$ +For example, this holds: - Some common special cases in which this condition is satisfied are: +- when $g$ is a diagonal operator (i.e. $g(y,t)$ is a diagonal matrix for all $y, t$ and the i-th diagonal entry depends only on $y_i$), +- when the dimension of the Brownian motion is $d=1$, or +- when the noise is additive (see below). - - the diffusion is additive ($σ$ is independent of $y$); - - the noise is scalar ($w$ is scalar-valued); - - the diffusion is diagonal ($σ$ is a diagonal matrix and such that the i-th - diagonal entry depends only on $y_i$; *not* to be confused with the simpler - but insufficient condition that $σ$ is only a diagonal matrix) +- The solver with the highest order of convergence for commutative noise SDEs is [`diffrax.SlowRK`][]. [`diffrax.ItoMilstein`][] and [`diffrax.StratonovichMilstein`][] are alternatives which evaluate the vector field fewer times per step, but also compute its derivative. + + +### Additive noise +We say that the diffusion is additive when $g$ does not depend on $y(t)$ and the SDE can be written as $dy(t) = f(y(t), t) dt + g(t) dw(t)$. + +Additive noise is a special case of commutative noise. For additive noise SDEs, the Itô and Stratonovich solutions coincide. Some solvers are specifically designed for additive noise SDEs; of these, [`diffrax.SEA`][] is the cheapest, [`diffrax.ShARK`][] is the most accurate and [`diffrax.SRA1`][] is another alternative. 
### Itô For Itô SDEs: +- For general noise [`diffrax.Euler`][] is a typical choice. - If the noise is commutative then [`diffrax.ItoMilstein`][] is a typical choice; -- If the noise is noncommutative then [`diffrax.Euler`][] is a typical choice. ### Stratonovich For Stratonovich SDEs: - If cheap low-accuracy solves are desired then [`diffrax.EulerHeun`][] is a typical choice. -- Otherwise, and if the noise is commutative, then [`diffrax.SlowRK`][] has the best order of convergence, but is expensive per step. [`diffrax.StratonovichMilstein`][] is a good cheap alternative. -- If the noise is noncommutative, [`diffrax.GeneralShARK`][] is the most efficient choice, while [`diffrax.Heun`][] is a good cheap alternative. -- If the noise is noncommutative and an embedded method for adaptive step size control is desired, then [`diffrax.SPaRK`][] is the recommended choice. - -### Additive noise - -Consider the SDE - -$\mathrm{d}y(t) = μ(t, y(t))\mathrm{d}t + σ(t, y(t))\mathrm{d}w(t)$ - -Then the diffusion matrix $σ$ is said to be additive if $σ(t, y) = σ(t)$. That is to say if the diffusion is independent of $y$. +- For general noise, [`diffrax.GeneralShARK`][] is the most efficient choice, while [`diffrax.Heun`][] is a good cheap alternative. +- If an embedded method for adaptive step size control is desired and the noise is noncommutative then [`diffrax.SPaRK`][] is the recommended choice. +- If the noise is commutative, then [`diffrax.SlowRK`][] has the best order of convergence, but is expensive per step. [`diffrax.StratonovichMilstein`][] is a good cheaper alternative. -In this case the Itô solution and the Stratonovich solution coincide, and mathematically speaking the choice of Itô vs Stratonovich is unimportant. Special solvers for additive-noise SDEs tend to do particularly well as compared to the general Itô or Stratonovich solvers discussed above. +### More information about SDE solvers -- The cheapest (but least accurate) solver is [`diffrax.SEA`][]. 
-- Otherwise [`diffrax.ShARK`][] or [`diffrax.SRA1`][] are good choices. +A detailed example of how to simulate SDEs can be found in the [SDE example](../examples/sde_example.ipynb). +A table of all SDE solvers and their properties can be found in [SDE solver table](../devdocs/SDE_solver_table.md). ### Underdamped Langevin Diffusion diff --git a/mkdocs.yml b/mkdocs.yml index a493b353..25ce7aa8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -123,6 +123,7 @@ nav: - Second-order sensitivities: 'examples/hessian.ipynb' - Nonlinear heat PDE: 'examples/nonlinear_heat_pde.ipynb' - Underdamped Langevin diffusion: 'examples/underdamped_langevin_example.ipynb' + - Advanced SDE simulation example: 'examples/sde_example.ipynb' - Basic API: - 'api/diffeqsolve.md' - Solvers: @@ -150,3 +151,4 @@ nav: - 'devdocs/predictor_dirk.md' - 'devdocs/adjoint_commutative_noise.md' - Stochastic Runge-Kutta methods: 'devdocs/srk_example.ipynb' + - Table of SDE solvers: 'devdocs/SDE_solver_table.md'