Description
Problem
In the implementation of OrdinaryDiffEqSDIRK schemes, the iteration matrix W is formed via the scaled formula
W = ∂ᵤf - (hγ)⁻¹ I
The same iteration matrix is used in order to get a smoothed error estimate, following a trick credited to Shampine (see e.g. https://doi.org/10.1016/0168-9274(95)00115-8). Indeed, using the embedded scheme, forming
err = u - û
may give an error estimate with the stability characteristics of an explicit method, which can become too conservative for stiff problems. To fix this, Shampine proposes evaluating a smoothed error estimate ERR defined by
M ERR = err
where the iteration matrix M is formed via the unscaled formula
M = I - hγ ∂ᵤf
Whether this is necessary for methods with L-stable main and embedded methods is unclear (to me).
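To illustrate what the smoothing does, here is a minimal sketch on a hypothetical linear problem with one slow and one stiff mode. The Jacobian J, the step size h, the coefficient γ, and the raw estimate err are made-up values chosen for illustration; they are not taken from OrdinaryDiffEq.

```julia
using LinearAlgebra

# Hypothetical linear test problem u' = J*u with one slow and one stiff mode.
J = Diagonal([-1.0, -1.0e6])
h, γ = 0.1, 0.25            # illustrative step size and SDIRK diagonal coefficient
err = [1.0e-4, 1.0e-4]      # made-up raw embedded estimate u - û

M = I - h * γ * J           # unscaled iteration matrix M = I - hγ ∂ᵤf
ERR = M \ err               # smoothed estimate: solve M * ERR = err

println("raw estimate:      ", err)
println("smoothed estimate: ", ERR)
```

The slow component is left almost unchanged, while the stiff component is damped by a factor of roughly 1 / (1 + hγ|λ|), which is the intended effect of the filter.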
The schemes in OrdinaryDiffEqSDIRK enable this smoothing by default through the keyword argument smooth_est=true. However, the implementation solves with the scaled iteration matrix W instead of the unscaled M. Since M = -hγ W, the computed smoothed error estimate ERR* satisfies
||ERR*|| = hγ ||ERR||
When stiffness forces small steps (h → 0, so hγ < 1), ERR* is smaller than the intended estimate ERR, which makes error control unreliable. Conversely, for large steps (hγ > 1) the estimate is too conservative, which is exactly what Shampine's trick was meant to avoid.
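For completeness, the factor hγ follows directly from the two definitions of the iteration matrix given above. A short derivation, writing J for ∂ᵤf:

```latex
\begin{aligned}
M &= I - h\gamma J = -h\gamma\left(J - (h\gamma)^{-1} I\right) = -h\gamma\, W, \\
W^{-1} &= -h\gamma\, M^{-1}, \\
\mathrm{ERR}^{*} &= W^{-1}\,\mathrm{err} = -h\gamma\, M^{-1}\,\mathrm{err} = -h\gamma\,\mathrm{ERR}, \\
\lVert \mathrm{ERR}^{*} \rVert &= h\gamma\, \lVert \mathrm{ERR} \rVert .
\end{aligned}
```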
Example
To test whether this has any practical effect on performance, I tested the stiff (μ = 100) Van der Pol oscillator problem. Using the unsmoothed (conservative) error estimate, I computed a high-precision reference solution of the first solution component, and made sure to stop the integrator at a set of reference timepoints.
I then computed the absolute error for a selection of equal absolute and relative tolerances, reltol=abstol=Tol, and plotted it at the reference timepoints. I also plotted the work-precision diagram for the final reference timepoint. In the plots, circle markers (:circle) correspond to smooth_est=false and star markers (:star) to smooth_est=true (the default).
From the error plots, we see that the solution error is often magnified after the stiff hump (for which h → 0). Surprisingly, from the work-precision diagram, it is less obvious whether this implementation error has a performance impact.
Solution
Compute the smoothed error estimate using the scaled formula
W ERR = -(hγ)⁻¹ err
The minus sign follows from M = -hγ W; it does not affect the norm of ERR.
It is likely that the smoothed error estimate will not make much of a difference for many of the L-stable SDIRK formulas. However, the current default appears to deliver less reliable tolerance control.
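The relation between the two solves can be checked numerically. The sketch below uses an arbitrary illustrative Jacobian and made-up values for h, γ, and err; the variable names ERR_unscaled_rhs and ERR_rescaled_rhs are hypothetical and do not correspond to anything in the OrdinaryDiffEq source.

```julia
using LinearAlgebra

J = [0.0 1.0; -1.0 -300.0]   # arbitrary illustrative Jacobian
h, γ = 1.0e-3, 0.25          # illustrative step size and diagonal coefficient
err = [1.0e-5, -2.0e-5]      # made-up raw embedded estimate

W = J - I / (h * γ)          # scaled iteration matrix
M = I - h * γ * J            # unscaled iteration matrix

ERR = M \ err                            # Shampine's smoothed estimate
ERR_unscaled_rhs = W \ err               # solving with W but without rescaling err
ERR_rescaled_rhs = W \ (-err / (h * γ))  # solving with W and the rescaled right-hand side

@show norm(ERR_unscaled_rhs) / norm(ERR)   # equals hγ up to rounding
@show norm(ERR_rescaled_rhs - ERR)         # zero up to rounding
```

Rescaling the right-hand side lets the already factorized W be reused, so no extra factorization of M should be needed.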
Version info
[1dea7af3] OrdinaryDiffEq v6.102.1
Example code
using OrdinaryDiffEq
using Plots, Printf
using BenchmarkTools
# Setup plots
colors = palette(:seaborn_colorblind)
# Reference plot
plt = plot(
    title="Reference Trajectory", xlabel="t", ylabel="xᵣ(t)",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
)
# Error plots
plt1 = plot(
    title="Error (KenCarp47)", xlabel="t", ylabel="|x(t)-xᵣ(t)|",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, legend=:bottomright
)
plt2 = plot(
    title="Error (TRBDF2)", xlabel="t", ylabel="|x(t)-xᵣ(t)|",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, legend=:bottomright
)
# Work-precision plots
plt3 = plot(
    title="Work-Precision (KenCarp47)", xlabel="|x(T)-xᵣ(T)|", ylabel="CPU Time [s]",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, xscale=:log10, legend=:topright
)
plt4 = plot(
    title="Work-Precision (TRBDF2)", xlabel="|x(T)-xᵣ(T)|", ylabel="CPU Time [s]",
    titlefontsize=18, labelfontsize=14, tickfontsize=12,
    yscale=:log10, xscale=:log10, legend=:topright
)
# Van der Pol oscillator
function f!(dx, x, p, t)
    μ = p[1]
    dx[1] = x[2]
    dx[2] = μ * (1 - x[1]^2) * x[2] - x[1]
    return nothing
end
x₀ = [2.0, 0.0]
p = [100.0]
tspan = [0.0, 0.9(p[1] * (3 - 2 * log(2)) + π)]
prob = ODEProblem(f!, x₀, tspan, p)
# Tolerances and timepoints
N = 10
tols = [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]
tsref = range(tspan[1], tspan[2], N + 1)[2:end]
# Baseline
solref = solve(prob, KenCarp47(smooth_est=false), reltol=1e-12, abstol=1e-12, tstops=tsref)
xs = hcat(solref.u...)[1, :]
xsref = hcat(solref.(tsref)...)[1, :]
plot!(plt, solref.t, xs, color=:black, lw=2, label="")
scatter!(plt, tsref, xsref, color=:black, marker=:circle, ms=6, label="")
# Plot baseline
pltout = plot(plt, size=(600,400), margin=Plots.cm)
savefig(pltout, "reference.png")
# KenCarp47
for (i, tol) in enumerate(tols)
    # Timings with and without the smoothed error estimate
    t1 = @belapsed solve($prob, $KenCarp47(smooth_est=$false), reltol=$tol, abstol=$tol, tstops=$tsref, saveat=$tsref) seconds=1
    t2 = @belapsed solve($prob, $KenCarp47(smooth_est=$true), reltol=$tol, abstol=$tol, tstops=$tsref, saveat=$tsref) seconds=1
    # Solutions saved at the reference timepoints
    sol1 = solve(prob, KenCarp47(smooth_est=false), reltol=tol, abstol=tol, tstops=tsref, saveat=tsref)
    sol2 = solve(prob, KenCarp47(smooth_est=true), reltol=tol, abstol=tol, tstops=tsref, saveat=tsref)
    xs1 = hcat(sol1.u...)[1, :]
    xs2 = hcat(sol2.u...)[1, :]
    scatter!(plt1, tsref, abs.(xs1-xsref), color=colors[i], marker=:circle, ms=6, label="Tol = $(@sprintf "%1.E" tol)")
    scatter!(plt1, tsref, abs.(xs2-xsref), color=colors[i], marker=:star, ms=6, label="")
    scatter!(plt3, [abs((xs1-xsref)[end])], [t1], color=colors[i], marker=:circle, ms=6, label="Tol = $(@sprintf "%1.E" tol)")
    scatter!(plt3, [abs((xs2-xsref)[end])], [t2], color=colors[i], marker=:star, ms=6, label="")
end
# TRBDF2
for (i, tol) in enumerate(tols)
    # Timings with and without the smoothed error estimate
    t1 = @belapsed solve($prob, $TRBDF2(smooth_est=$false), reltol=$tol, abstol=$tol, tstops=$tsref, saveat=$tsref) seconds=1
    t2 = @belapsed solve($prob, $TRBDF2(smooth_est=$true), reltol=$tol, abstol=$tol, tstops=$tsref, saveat=$tsref) seconds=1
    # Solutions saved at the reference timepoints
    sol1 = solve(prob, TRBDF2(smooth_est=false), reltol=tol, abstol=tol, tstops=tsref, saveat=tsref)
    sol2 = solve(prob, TRBDF2(smooth_est=true), reltol=tol, abstol=tol, tstops=tsref, saveat=tsref)
    xs1 = hcat(sol1.u...)[1, :]
    xs2 = hcat(sol2.u...)[1, :]
    scatter!(plt2, tsref, abs.(xs1-xsref), color=colors[i], marker=:circle, ms=6, label="Tol = $(@sprintf "%1.E" tol)")
    scatter!(plt2, tsref, abs.(xs2-xsref), color=colors[i], marker=:star, ms=6, label="")
    scatter!(plt4, [abs((xs1-xsref)[end])], [t1], color=colors[i], marker=:circle, ms=6, label="Tol = $(@sprintf "%1.E" tol)")
    scatter!(plt4, [abs((xs2-xsref)[end])], [t2], color=colors[i], marker=:star, ms=6, label="")
end
# Plot error and work-precision
pltout = plot(plt1, plt2, plt3, plt4, layout=(2,2), size=(1200,800), margin=Plots.cm)
savefig(pltout, "errorwp.png")