
OrdinaryDiffEqSDIRK implements wrong smooth_est error estimate #2902

@mschytt


Problem

In the implementation of the OrdinaryDiffEqSDIRK schemes, with step size h and diagonal coefficient γ of the SDIRK tableau, the iteration matrix W is formed via the scaled formula

W = ∂ᵤf - (hγ)⁻¹ I 

The same iteration matrix is also used to obtain a smoothed error estimate, following a trick credited to Shampine (see e.g. https://doi.org/10.1016/0168-9274(95)00115-8). The plain estimate obtained from the embedded solution û,

err = u - û

can have the stability characteristics of an explicit method and become too conservative for stiff problems. To fix this, Shampine proposes computing a smoothed error estimate ERR by solving

M ERR = err

where the iteration matrix M is formed via the unscaled formula

M = I - hγ ∂ᵤf 
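To see what the smoothing does, consider the scalar test equation u' = λu, for which ∂ᵤf = λ and M reduces to the scalar 1 - hγλ. The following minimal Julia sketch assumes illustrative values of h, γ, λ and a hypothetical helper `smooth_scalar` (none of these come from OrdinaryDiffEq). It shows that the smoothed estimate ERR = err / (1 - hγλ) strongly damps stiff components (hλ → -∞) while leaving non-stiff components (hλ ≈ 0) essentially unchanged:

```julia
# Hypothetical helper (not part of OrdinaryDiffEq): Shampine-style smoothing
# for the scalar test equation u' = λu, where M = 1 - hγλ is a scalar.
smooth_scalar(err, h, γ, λ) = err / (1 - h * γ * λ)

γ = 0.25      # illustrative diagonal coefficient
h = 0.1       # illustrative step size
err = 1e-3    # illustrative raw estimate u - û

# Non-stiff component (hλ = -0.1): the estimate is essentially unchanged
println(smooth_scalar(err, h, γ, -1.0))

# Stiff component (hλ = -1e5): the estimate is damped by a factor of about 2.5e4
println(smooth_scalar(err, h, γ, -1e6))
```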

Whether this is necessary when both the main and the embedded method are L-stable is unclear (to me).

The schemes in OrdinaryDiffEqSDIRK apply this smoothing by default (keyword argument smooth_est=true). However, the implementation solves with the scaled iteration matrix W instead of the unscaled M! Since M = -hγ W, the computed smoothed error estimate is ERR* = W⁻¹ err = -hγ M⁻¹ err = -hγ ERR, so that

||ERR*|| = hγ ||ERR||

When stiffness forces small steps (h → 0), the error is therefore underestimated by the factor hγ, which yields unreliable error control. Conversely, for large steps (hγ > 1) the error estimate is inflated and too conservative, which is exactly what Shampine's trick was meant to avoid.
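The scaling relation can be checked numerically. The sketch below uses only the LinearAlgebra standard library, with an arbitrary stiff 2×2 Jacobian `J` and illustrative values of h, γ and err (all assumptions, not taken from the package). It solves once with the unscaled M, as Shampine prescribes, and once with the scaled W, as the issue describes the current implementation doing, then compares the results:

```julia
using LinearAlgebra

# Illustrative stiff Jacobian ∂ᵤf (eigenvalues -1 and -1e4) and step data
J = [-1.0 1.0; 0.0 -1e4]
γ = 0.25
h = 1e-3
err = [1e-4, 1e-4]          # illustrative raw estimate u - û

M = I - h * γ * J           # unscaled matrix (Shampine)
W = J - inv(h * γ) * I      # scaled matrix used for the Newton iteration

ERR = M \ err               # intended smoothed estimate
ERRstar = W \ err           # estimate obtained by solving with W instead

# Since M = -hγ W, the two estimates differ exactly by the factor -hγ
@show ERRstar ≈ -h * γ * ERR
@show norm(ERRstar) / norm(ERR)   # ≈ hγ
```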

Example

To check whether this has any practical effect, I tested the stiff Van der Pol oscillator (μ = 100). Using the unsmoothed (conservative) error estimate, I computed a high-precision reference solution of the first solution component, making sure to stop the integrator at a set of reference timepoints.

[Figure: Reference Trajectory, first component xᵣ(t) versus t, with the reference timepoints marked]

I then computed the absolute error at the reference timepoints for a selection of tolerances, using equal absolute and relative tolerances (reltol = abstol = Tol), and plotted it. I also plotted the work-precision diagram for the final reference timepoint. Circles (:circle) correspond to smooth_est=false and stars (:star) to smooth_est=true (the default).

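[Figure: four panels, Error (KenCarp47), Error (TRBDF2), Work-Precision (KenCarp47), Work-Precision (TRBDF2); the error panels show |x(t)-xᵣ(t)| versus t, the work-precision panels show CPU time [s] versus |x(T)-xᵣ(T)|]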

From the error plots, we see that the solution error is often magnified after the stiff hump (for which h → 0). Surprisingly, from the work-precision diagram, it is less obvious whether this implementation error has a performance impact.

Solution

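Compute the smoothed error estimate with the scaled matrix W, but rescale the right-hand side accordingly:

W ERR = -(hγ)⁻¹ err

Given M = -hγ W, this is equivalent to M ERR = err (the sign is immaterial once the norm is taken). It is likely that the smoothed error estimate will not make much of a difference for many of the L-stable SDIRK formulas. However, the current default seemingly makes the requested tolerances less reliable.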
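A minimal sketch of the proposed fix, assuming the solver only has access to a factorization of the scaled matrix W: divide the right-hand side by -hγ before the solve. The helper `smoothed_estimate`, its signature and the numerical values are hypothetical and do not reproduce OrdinaryDiffEq's internals:

```julia
using LinearAlgebra

# Hypothetical helper: Shampine's smoothed estimate ERR = M⁻¹ err,
# computed from a factorization of the scaled matrix W = J - (hγ)⁻¹ I.
# Since M = -hγ W, we have M⁻¹ err = W⁻¹ (-(hγ)⁻¹ err).
function smoothed_estimate(Wfact, err, hγ)
    return Wfact \ (-err / hγ)
end

# Illustrative Jacobian and step data
J = [-1.0 1.0; 0.0 -1e4]
hγ = 2.5e-4
err = [1e-4, 1e-4]

# Stands in for the factorization reused from the Newton iteration
Wfact = lu(J - inv(hγ) * I)
ERR = smoothed_estimate(Wfact, err, hγ)

# Agrees with solving against the unscaled M = I - hγ J directly
@show ERR ≈ (I - hγ * J) \ err
```

Rescaling the right-hand side reuses the existing factorization of W, so no separate factorization of M is needed.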

Version info

[1dea7af3] OrdinaryDiffEq v6.102.1

Example code

using OrdinaryDiffEq
using Plots, Printf
using BenchmarkTools

# Setup plots
colors = palette(:seaborn_colorblind)

# Reference plot
plt = plot(
    title="Reference Trajectory", xlabel="t", ylabel="xᵣ(t)",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
)

# Error plots
plt1 = plot(
    title="Error (KenCarp47)", xlabel="t", ylabel="|x(t)-xᵣ(t)|",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, legend=:bottomright
)
plt2 = plot(
    title="Error (TRBDF2)", xlabel="t", ylabel="|x(t)-xᵣ(t)|",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, legend=:bottomright
)

# Work-precision plots
plt3 = plot(
    title="Work-Precision (KenCarp47)", xlabel="|x(T)-xᵣ(T)|", ylabel="CPU Time [s]",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, xscale=:log10, legend=:topright
)
plt4 = plot(
    title="Work-Precision (TRBDF2)", xlabel="|x(T)-xᵣ(T)|", ylabel="CPU Time [s]",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, xscale=:log10, legend=:topright
)

# Van der Pol oscillator
function f!(dx, x, p, t)
    μ = p[1]
    dx[1] = x[2]
    dx[2] = μ * (1 - x[1]^2) * x[2] - x[1]
    return nothing
end
x₀ = [2.0, 0.0]
p = [100.0]
tspan = (0.0, 0.9 * (p[1] * (3 - 2 * log(2)) + π))
prob = ODEProblem(f!, x₀, tspan, p)

# Tolerances and timepoints
N = 10
tols = [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]
tsref = range(tspan[1], tspan[2], N + 1)[2:end]

# Baseline
solref = solve(prob, KenCarp47(smooth_est=false), reltol=1e-12, abstol=1e-12, tstops=tsref)
xs = hcat(solref.u...)[1, :]
xsref = hcat(solref.(tsref)...)[1, :]
plot!(plt, solref.t, xs, color=:black, lw=2, label="")
scatter!(plt, tsref, xsref, color=:black, marker=:circle, ms=6, label="")

# Plot baseline
pltout = plot(plt, size=(600,400), margin=Plots.cm)
savefig(pltout, "reference.png")

# Compare smooth_est=false (circles) with smooth_est=true (stars) for one SDIRK method
function compare!(plt_err, plt_wp, Alg)
    for (i, tol) in enumerate(tols)
        for (smooth, marker) in ((false, :circle), (true, :star))
            alg = Alg(smooth_est=smooth)
            t = @belapsed solve($prob, $alg, reltol=$tol, abstol=$tol, tstops=$tsref, saveat=$tsref) seconds=1
            sol = solve(prob, alg, reltol=tol, abstol=tol, tstops=tsref, saveat=tsref)
            errs = abs.(hcat(sol.u...)[1, :] - xsref)
            # Only label the circles so each tolerance appears once in the legend
            label = smooth ? "" : "Tol = $(@sprintf "%1.E" tol)"
            scatter!(plt_err, tsref, errs, color=colors[i], marker=marker, ms=6, label=label)
            scatter!(plt_wp, [errs[end]], [t], color=colors[i], marker=marker, ms=6, label=label)
        end
    end
    return nothing
end

# KenCarp47
compare!(plt1, plt3, KenCarp47)

# TRBDF2
compare!(plt2, plt4, TRBDF2)

# Plot error and work-precision
pltout = plot(plt1, plt2, plt3, plt4, layout=(2,2), size=(1200,800), margin=Plots.cm)
savefig(pltout, "errorwp.png")
