diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index d35a7fac..fee98dd7 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -106,6 +106,7 @@
     QUICSORT as QUICSORT,
     Ralston as Ralston,
     ReversibleHeun as ReversibleHeun,
+    Ros3p as Ros3p,
     SEA as SEA,
     SemiImplicitEuler as SemiImplicitEuler,
     ShARK as ShARK,
diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py
index 6fc38ce3..0fa8f933 100644
--- a/diffrax/_integrate.py
+++ b/diffrax/_integrate.py
@@ -53,6 +53,7 @@
     Euler,
     EulerHeun,
     ItoMilstein,
+    Ros3p,
     StratonovichMilstein,
 )
 from ._step_size_controller import (
@@ -1034,6 +1035,10 @@ def diffeqsolve(
         eqx.is_array_like(xi) and jnp.iscomplexobj(xi)
         for xi in jtu.tree_leaves((terms, y0, args))
     ):
+        if isinstance(solver, Ros3p):
+            # TODO: add complex dtype support to ros3p.
+            raise ValueError("Ros3p does not support complex dtypes.")
+
         warnings.warn(
             "Complex dtype support in Diffrax is a work in progress and may not yet "
             "produce correct results. Consider splitting your computation into real "
diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py
index 0a840413..4feabdf8 100644
--- a/diffrax/_solver/__init__.py
+++ b/diffrax/_solver/__init__.py
@@ -31,6 +31,7 @@
 from .quicsort import QUICSORT as QUICSORT
 from .ralston import Ralston as Ralston
 from .reversible_heun import ReversibleHeun as ReversibleHeun
+from .ros3p import Ros3p as Ros3p
 from .runge_kutta import (
     AbstractDIRK as AbstractDIRK,
     AbstractERK as AbstractERK,
diff --git a/diffrax/_solver/ros3p.py b/diffrax/_solver/ros3p.py
new file mode 100644
index 00000000..04232039
--- /dev/null
+++ b/diffrax/_solver/ros3p.py
@@ -0,0 +1,239 @@
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import ClassVar, TypeAlias
+
+import equinox.internal as eqxi
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.tree_util as jtu
+import lineax as lx
+import numpy as np
+from equinox.internal import ω
+from jaxtyping import ArrayLike
+
+from .._custom_types import (
+    Args,
+    BoolScalarLike,
+    DenseInfo,
+    RealScalarLike,
+    VF,
+    Y,
+)
+from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
+from .._solution import RESULTS
+from .._term import AbstractTerm
+from .base import AbstractAdaptiveSolver
+
+_SolverState: TypeAlias = VF
+
+
+@dataclass(frozen=True)
+class _RosenbrockTableau:
+    """The coefficient tableau for Rosenbrock methods."""
+
+    # Stage weights for the solution and for the solution used to estimate the error.
+    m_sol: np.ndarray
+    m_error: np.ndarray
+
+    # Strictly lower-triangular coefficients, stored row by row: entry `i` holds
+    # the `i + 1` coefficients of row `i + 1` (the first row is all zeros).
+    a_lower: tuple[np.ndarray, ...]
+    c_lower: tuple[np.ndarray, ...]
+
+    α: np.ndarray
+    γ: np.ndarray
+
+    num_stages: int
+
+    # Example tableau
+    #
+    # α1 | a11 a12 a13 | c11 c12 c13 | γ1
+    # α2 | a21 a22 a23 | c21 c22 c23 | γ2
+    # α3 | a31 a32 a33 | c31 c32 c33 | γ3
+    # ---+-------------+-------------+---
+    #    | m1  m2  m3
+    #    | me1 me2 me3
+
+
+_tableau = _RosenbrockTableau(
+    m_sol=np.array([2.0, 0.5773502691896258, 0.4226497308103742]),
+    m_error=np.array([2.113248654051871, 1.0, 0.4226497308103742]),
+    a_lower=(
+        np.array([1.267949192431123]),
+        np.array([1.267949192431123, 0.0]),
+    ),
+    c_lower=(
+        np.array([-1.607695154586736]),
+        np.array([-3.464101615137755, -1.732050807568877]),
+    ),
+    α=np.array([0.0, 1.0, 1.0]),
+    γ=np.array(
+        [
+            0.7886751345948129,
+            -0.2113248654051871,
+            -1.0773502691896260,
+        ]
+    ),
+    num_stages=3,
+)
+
+
+class Ros3p(AbstractAdaptiveSolver):
+    r"""Ros3p method.
+
+    3rd order Rosenbrock method for solving stiff equations. Uses third-order Hermite
+    polynomial interpolation for dense output.
+
+    ??? cite "Reference"
+
+        ```bibtex
+        @article{LangVerwer2001ROS3P,
+            author = {Lang, J. and Verwer, J.},
+            title = {ROS3P---An Accurate Third-Order Rosenbrock Solver Designed
+                     for Parabolic Problems},
+            journal = {BIT Numerical Mathematics},
+            volume = {41},
+            number = {4},
+            pages = {731--738},
+            year = {2001},
+            doi = {10.1023/A:1021900219772}
+        }
+        ```
+    """
+
+    term_structure: ClassVar = AbstractTerm[ArrayLike, ArrayLike]
+    interpolation_cls: ClassVar[
+        Callable[..., ThirdOrderHermitePolynomialInterpolation]
+    ] = ThirdOrderHermitePolynomialInterpolation.from_k
+
+    tableau: ClassVar[_RosenbrockTableau] = _tableau
+
+    def init(self, terms, t0, t1, y0, args) -> _SolverState:
+        # The solver state is the vector field at the start of the next step.
+        del t1
+        return terms.vf(t0, y0, args)
+
+    def order(self, terms):
+        return 3
+
+    def step(
+        self,
+        terms: AbstractTerm[ArrayLike, ArrayLike],
+        t0: RealScalarLike,
+        t1: RealScalarLike,
+        y0: Y,
+        args: Args,
+        solver_state: _SolverState,
+        made_jump: BoolScalarLike,
+    ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
+        y0_leaves = jtu.tree_leaves(y0)
+        sol_dtype = jnp.result_type(*y0_leaves)
+
+        # Partial derivative of the vector field with respect to time at (t0, y0).
+        time_derivative = jax.jacfwd(lambda t: terms.vf(t, y0, args))(t0)
+        control = terms.contr(t0, t1)
+
+        γ = jnp.array(self.tableau.γ, dtype=sol_dtype)
+        α = jnp.array(self.tableau.α, dtype=sol_dtype)
+
+        def embed_lower(x):
+            # Scatter the ragged rows into a dense strictly lower-triangular matrix.
+            out = np.zeros(
+                (self.tableau.num_stages, self.tableau.num_stages), dtype=x[0].dtype
+            )
+            for i, val in enumerate(x):
+                out[i + 1, : i + 1] = val
+            return jnp.array(out, dtype=sol_dtype)
+
+        a_lower = embed_lower(self.tableau.a_lower)
+        c_lower = embed_lower(self.tableau.c_lower)
+        m_sol = jnp.array(self.tableau.m_sol, dtype=sol_dtype)
+        m_error = jnp.array(self.tableau.m_error, dtype=sol_dtype)
+
+        # Common left-hand side of every stage: I / (γ * control) - ∂f/∂y(t0, y0).
+        eye_shape = jax.ShapeDtypeStruct(time_derivative.shape, dtype=sol_dtype)
+        A = (lx.IdentityLinearOperator(eye_shape) / (control * γ[0])) - (
+            lx.JacobianLinearOperator(
+                lambda y, args: terms.vf(t0, y, args), y0, args=args
+            )
+        )
+
+        # u[i] holds the solution of the linear system of stage i.
+        u = jnp.zeros(
+            (self.tableau.num_stages,) + time_derivative.shape, dtype=sol_dtype
+        )
+
+        # Stage 0 is evaluated at (t0, y0), so it reuses the vector field saved at the
+        # end of the previous step unless a jump has invalidated it. `made_jump` may
+        # be a traced value, so it is only allowed to select the *value* of `vf0`:
+        # the stages that get computed must be known statically, because
+        # `jnp.arange` and Python `if` statements need concrete values.
+        if made_jump is False:
+            # Fast-path for compilation when it is statically known there is no jump.
+            vf0 = solver_state
+        else:
+            vf0 = lax.cond(
+                eqxi.unvmap_any(made_jump),
+                lambda: terms.vf(t0, y0, args),
+                lambda: solver_state,
+            )
+        b0 = (vf0**ω + (control**ω * γ[0] ** ω * time_derivative**ω)).ω
+        u = u.at[0].set(lx.linear_solve(A, b0).value)
+
+        def body(u, stage):
+            vf = terms.vf(
+                (t0**ω + α[stage] ** ω * control**ω).ω,
+                (
+                    y0**ω
+                    + (a_lower[stage][0] ** ω * u[0] ** ω)
+                    + (a_lower[stage][1] ** ω * u[1] ** ω)
+                ).ω,
+                args,
+            )
+            b = (
+                vf**ω
+                + ((c_lower[stage][0] ** ω / control**ω) * u[0] ** ω)
+                + ((c_lower[stage][1] ** ω / control**ω) * u[1] ** ω)
+                + (control**ω * γ[stage] ** ω * time_derivative**ω)
+            ).ω
+            stage_u = lx.linear_solve(A, b).value
+            u = u.at[stage].set(stage_u)
+            return u, None
+
+        # The remaining stages; the range is static so it can be scanned over.
+        u, _ = lax.scan(f=body, init=u, xs=jnp.arange(1, self.tableau.num_stages))
+
+        y1 = (
+            y0**ω
+            + m_sol[0] ** ω * u[0] ** ω
+            + m_sol[1] ** ω * u[1] ** ω
+            + m_sol[2] ** ω * u[2] ** ω
+        ).ω
+        y1_lower = (
+            y0**ω
+            + m_error[0] ** ω * u[0] ** ω
+            + m_error[1] ** ω * u[1] ** ω
+            + m_error[2] ** ω * u[2] ** ω
+        ).ω
+        y1_error = y1 - y1_lower
+
+        # The vector field at (t1, y1) is used for the dense output and is returned as
+        # the solver state, so that the next step can reuse it for its first stage.
+        vf1 = terms.vf(t1, y1, args)
+        k = jnp.stack((terms.prod(vf0, control), terms.prod(vf1, control)))
+
+        dense_info = dict(y0=y0, y1=y1, k=k)
+        return y1, y1_error, dense_info, vf1, RESULTS.successful
+
+    def func(
+        self,
+        terms: AbstractTerm[ArrayLike, ArrayLike],
+        t0: RealScalarLike,
+        y0: Y,
+        args: Args,
+    ) -> VF:
+        return terms.vf(t0, y0, args)
+
+
+Ros3p.__init__.__doc__ = """**Arguments:** None"""
diff --git a/test/helpers.py b/test/helpers.py
index 97b0f074..f49b0b1e 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -38,6 +38,7 @@
     diffrax.Kvaerno3(),
     diffrax.Kvaerno4(),
     diffrax.Kvaerno5(),
+    diffrax.Ros3p(),
 )
 
all_split_solvers = ( diff --git a/test/test_detest.py b/test/test_detest.py index 6dbb20e3..61234877 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -418,6 +418,12 @@ def _test(solver, problems, higher): # size. (To avoid the adaptive step sizing sabotaging us.) dt0 = 0.001 stepsize_controller = diffrax.ConstantStepSize() + elif type(solver) is diffrax.Ros3p and problem is _a1: + # Ros3p underestimates the error for _a1. This causes the step-size controller + # to take larger steps and results in an inaccurate solution. + dt0 = 0.0001 + max_steps = 20_000_001 + stepsize_controller = diffrax.ConstantStepSize() else: dt0 = None if solver.order(term) < 4: # pyright: ignore @@ -427,6 +433,7 @@ def _test(solver, problems, higher): rtol = 1e-8 atol = 1e-8 stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol) + sol = diffrax.diffeqsolve( term, solver=solver, diff --git a/test/test_integrate.py b/test/test_integrate.py index cfcaadfd..e9a4954b 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -150,6 +150,10 @@ def test_ode_order(solver, dtype): A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5 + if isinstance(solver, diffrax.Ros3p) and dtype == jnp.complex128: + ## complex support is not added to ros3p. + return + if ( solver.term_structure == diffrax.MultiTerm[tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]] diff --git a/test/test_interpolation.py b/test/test_interpolation.py index d299b090..0ac6b47f 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -57,6 +57,10 @@ def test_derivative(dtype, getkey): paths.append((local_linear_interp, "local linear", ys[0], ys[-1])) for solver in all_ode_solvers: + if isinstance(solver, diffrax.Ros3p) and dtype == jnp.complex128: + # ros3p does not support complex type. 
+ continue + solver = implicit_tol(solver) y0 = jr.normal(getkey(), (3,), dtype=dtype) diff --git a/test/test_solver.py b/test/test_solver.py index a022f644..331eec43 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -58,9 +58,9 @@ class _DoubleDopri5(diffrax.AbstractRungeKutta): tableau: ClassVar[diffrax.MultiButcherTableau] = diffrax.MultiButcherTableau( diffrax.Dopri5.tableau, diffrax.Dopri5.tableau ) - calculate_jacobian: ClassVar[diffrax.CalculateJacobian] = ( - diffrax.CalculateJacobian.never - ) + calculate_jacobian: ClassVar[ + diffrax.CalculateJacobian + ] = diffrax.CalculateJacobian.never @staticmethod def interpolation_cls(**kwargs): @@ -415,6 +415,7 @@ def f2(t, y, args): diffrax.KenCarp3(), diffrax.KenCarp4(), diffrax.KenCarp5(), + diffrax.Ros3p(), ), ) def test_rober(solver): @@ -479,6 +480,38 @@ def vector_field(t, y, args): f(1.0) +def test_ros3p(): + term = diffrax.ODETerm(lambda t, y, args: -50.0 * y + jnp.sin(t)) + solver = diffrax.Ros3p() + t0 = 0 + t1 = 5 + y0 = 0 + ts = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64) + saveat = diffrax.SaveAt(ts=ts) + + stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-12) + sol = diffrax.diffeqsolve( + term, + solver, + t0=t0, + t1=t1, + dt0=0.1, + y0=y0, + stepsize_controller=stepsize_controller, + max_steps=60000, + saveat=saveat, + ) + + def exact_sol(t): + return ( + jnp.exp(-50.0 * t) * (y0 + 1 / 2501) + + (50.0 * jnp.sin(t) - jnp.cos(t)) / 2501 + ) + + ys_ref = jtu.tree_map(exact_sol, ts) + tree_allclose(ys_ref, sol.ys) + + # Doesn't crash def test_adaptive_dt0_semiimplicit_euler(): f = diffrax.ODETerm(lambda t, y, args: y)