From e946857ee488fbb9ef269bfbb33425e72e6791d1 Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Wed, 5 Nov 2025 11:46:58 +0000 Subject: [PATCH 1/2] fix float error in prev_dt step calculation that led to an infinite loop When: dt is clipped to dtmin, and we wish to continue solver (force_dtmin=True) Calculating if a step should be kept from: prev_dt = t1 - t0 (next_t1 = next_t0 + dt (in previous step)) keep_step = keep_step | (prev_dt <= self.dtmin) can result in float error for high t0 where prev_dt is never <= self.dtmin, and further steps are never accepted -> infinite loop Fix: add a keep_next_step: bool flag to controller_state, and track when we are, and continue to be, at dtmin --- diffrax/_step_size_controller/pid.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 7fb034fb..8040a4c0 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,8 +81,8 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error) -_PidState = tuple[RealScalarLike, RealScalarLike] +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] # We use a metaclass for backwards compatibility. When a user calls @@ -388,6 +388,7 @@ def init( return t1, ( jnp.array(1.0, dtype=real_dtype), jnp.array(1.0, dtype=real_dtype), + False, ) def adapt_step_size( @@ -469,6 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, + keep_next_step, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -489,9 +491,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 - # Automatically keep the step if we're at dtmin. 
+ # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | (prev_dt <= self.dtmin) + keep_step = keep_step | keep_next_step # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -545,6 +547,9 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) + # flag next step to be kept if dtmin is reached + # or if it was reached previously and dt is unchanged + keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -554,7 +559,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error + controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result From 53958470d5d82685749deee85cfcb12183b75b7a Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Thu, 6 Nov 2025 09:08:57 +0000 Subject: [PATCH 2/2] avoids accumulation of float precision errors in dt This solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1). If we do not reset dt, then the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors, with the potential to drift away from the desired dtmin until a step that warrants a relaxation of step size (factor>1). These errors are likely to be minor, but I believe this is the intended behaviour. --- diffrax/_step_size_controller/pid.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git 
a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 8040a4c0..2dcb09f0 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,7 +81,7 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin) _PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] @@ -470,7 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, - keep_next_step, + at_dtmin, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -493,7 +493,7 @@ def _scale(_y0, _y1_candidate, _y_error): keep_step = scaled_error < 1 # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | keep_next_step + keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. 
inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -547,9 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) - # flag next step to be kept if dtmin is reached - # or if it was reached previously and dt is unchanged - keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) + # if we are already at dtmin and dt is unchanged (factor == 1), + # reset dt to dtmin to avoid accumulating float precision errors + dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt) + # this flags the next loop to accept step + at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -559,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step + controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result