Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions diffrax/_step_size_controller/pid.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def intermediate(carry):
return jnp.minimum(100 * h0, h1)


# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error)
_PidState = tuple[RealScalarLike, RealScalarLike]
# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin)
_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike]


# We use a metaclass for backwards compatibility. When a user calls
Expand Down Expand Up @@ -388,6 +388,7 @@ def init(
return t1, (
jnp.array(1.0, dtype=real_dtype),
jnp.array(1.0, dtype=real_dtype),
False,
)

def adapt_step_size(
Expand Down Expand Up @@ -469,6 +470,7 @@ def adapt_step_size(
(
prev_inv_scaled_error,
prev_prev_inv_scaled_error,
at_dtmin,
) = controller_state
error_order = self._get_error_order(error_order)
prev_dt = t1 - t0
Expand All @@ -489,9 +491,9 @@ def _scale(_y0, _y1_candidate, _y_error):

scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error))
keep_step = scaled_error < 1
# Automatically keep the step if we're at dtmin.
# Automatically keep the step if it was at dtmin.
if self.dtmin is not None:
keep_step = keep_step | (prev_dt <= self.dtmin)
keep_step = keep_step | at_dtmin
# Make sure it's not a Python scalar and thus getting a ZeroDivisionError.
inv_scaled_error = 1 / jnp.asarray(scaled_error)
inv_scaled_error = lax.stop_gradient(
Expand Down Expand Up @@ -545,6 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error):
if self.dtmin is not None:
if not self.force_dtmin:
result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result)
# if we are already at dtmin and dt is unchanged (factor == 1),
# reset dt to dtmin to avoid accumulating float precision errors
dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think dt = prev_dt * 1 should imply dt = prev_dt exactly regardless of precision errors / floating point weirdness. Did you find this line necessary?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my use case (probably most others too), this doesn't really make a difference. I just wanted a failsafe to avoid max steps reached errors.

But thought I'd add this for completeness. Say we are at dtmin for many steps.

dt = dtmin
t1 = t0 + dt
## next step
prev_dt = t1 - t0 = (dt + e), where e is float error
t1 = t0 + (dt + e)
## next step
prev_dt = t1 - t0 = ((dt + e) + e)
t1 = t0 + ((dt + e) + e)
...
and so on

However, I just tested it, and dt seems to just get truncated to the precision of t0 and stays there until t0 jumps to the next exponent. So the error accumulation is probably negligible for most.

The lines can probably go unless you really, really care about being as close as possible to dtmin at all times?

# this flags the next loop to accept step
at_dtmin = dt <= self.dtmin
dt = jnp.maximum(dt, self.dtmin)

next_t0 = jnp.where(keep_step, t1, t0)
Expand All @@ -554,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error):
prev_inv_scaled_error = jnp.where(
keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error
)
controller_state = inv_scaled_error, prev_inv_scaled_error
controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin
# made_jump is handled by ClipStepSizeController, so we automatically set it to
# False
return keep_step, next_t0, next_t1, False, controller_state, result
Expand Down