@philipwijesinghe

A proposed fix for an unintended infinite loop encountered in PIDController, issue: #703

issue

This situation is encountered in pid.PIDController.adapt_step_size() when we set a dtmin and force_dtmin=True.

When the previous step is performed at dt=dtmin, the intention was for keep_step to be set to True:

if self.dtmin is not None:
    keep_step = keep_step | (prev_dt <= self.dtmin)

However, prev_dt is recomputed as prev_dt = t1 - t0, whereas the previous step set next_t1 = next_t0 + dt.
The intention is that prev_dt equals dt from the last step, but when t0 and t1 have larger floating-point exponents than dt, the sum t0 + dt is rounded and prev_dt loses precision. There is then a good chance of (prev_dt <= self.dtmin) being False.

In this situation nothing changes between iterations, so the same step is attempted and rejected again and again, and the solver loops forever.
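The rounding problem can be reproduced with plain Python floats. This is only a sketch: the values of t0 and dtmin are hand-picked (an assumption, not taken from the issue) so that t0 + dtmin rounds upwards in float64, and the variable names simply mirror the ones used above.

```python
# Hand-picked values (assumption): dtmin carries bits below the float64
# resolution of t0, so the sum t0 + dtmin has to be rounded.
t0 = 1.0
dtmin = 2.0**-20 + 0.75 * 2.0**-52

# Previous step was taken at dt == dtmin.
t1 = t0 + dtmin

# The controller later recovers the step size from the two time points.
prev_dt = t1 - t0

# The recovered step is slightly larger than dtmin, so the
# "previous step was at dtmin" check does not fire.
keep_step = prev_dt <= dtmin

print(prev_dt > dtmin, keep_step)
```

With these values the comparison evaluates to False even though the step was taken at exactly dtmin, which is the condition that leaves the loop stuck.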

solution

Instead of (prev_dt <= self.dtmin) we add a bool flag at_dtmin which we pass via the controller_state.

We set this to True when dt <= self.dtmin. Importantly, we also want to keep it as True if we are at dtmin already, and we want to avoid an accumulation of precision errors if we stay at dtmin for a long time. So, in this case, we keep resetting dt to the desired dtmin.

if self.dtmin is not None:
    if not self.force_dtmin:
        result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result)
    # if we are already at dtmin and dt is unchanged (factor == 1),
    # reset dt to dtmin to avoid accumulating float precision errors
    dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt)
    # this flags the next loop to accept step
    at_dtmin = dt <= self.dtmin
    dt = jnp.maximum(dt, self.dtmin)
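For readers who want to trace the flag logic without JAX, here is a scalar sketch of the same branch. The helper names clip_to_dtmin and keep_step_decision are hypothetical and not part of the library; the sketch assumes force_dtmin=True and leaves out the RESULTS handling.

```python
def clip_to_dtmin(dt, factor, at_dtmin, dtmin):
    """Hypothetical scalar stand-in for the dtmin branch (force_dtmin=True)."""
    # Already at dtmin and the controller proposes no change:
    # snap dt back to the exact dtmin value.
    if at_dtmin and factor == 1:
        dt = dtmin
    # Flag read on the next call to force acceptance of the step.
    at_dtmin = dt <= dtmin
    # Never step below dtmin.
    dt = max(dt, dtmin)
    return dt, at_dtmin


def keep_step_decision(error_ok, at_dtmin):
    # The flag replaces the fragile `prev_dt <= dtmin` comparison.
    return error_ok or at_dtmin


dtmin = 1e-3

# The controller wants to shrink below dtmin: dt is clipped, the flag is set.
dt, at_dtmin = clip_to_dtmin(dt=2.5e-4, factor=0.25, at_dtmin=False, dtmin=dtmin)
keep = keep_step_decision(error_ok=False, at_dtmin=at_dtmin)

# Later the controller relaxes the step size: the flag is cleared again.
dt_relaxed, at_dtmin_relaxed = clip_to_dtmin(
    dt=2e-3, factor=2.0, at_dtmin=True, dtmin=dtmin
)
```

Because acceptance depends on a boolean carried in the state rather than on a recomputed difference, it no longer matters how t1 - t0 happens to round.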

When: dt is clipped to dtmin, and we wish to continue solver (force_dtmin=True)

Calculating if a step should be kept from:
prev_dt = t1 - t0  (next_t1 = next_t0 + dt (in previous step))
keep_step = keep_step | (prev_dt <= self.dtmin)

can result in float error for high t0 where prev_dt is never <= self.dtmin, and further steps are never accepted -> infinite loop

Fix: add a keep_next_step: bool flag to controller_state, and track when we are, and continue to be, at dtmin
this solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1)

if we do not reset dt then the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors, with the potential to drift away from the desired dtmin until a step warrants a relaxation of the step size (factor > 1)

these errors are likely to be minor, but I believe this is the intended behaviour
@johannahaffner
Contributor

Thanks! I will take a look as soon as I'm able :)

Owner

@patrick-kidger left a comment

One comment aside, this LGTM!

result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result)
# if we are already at dtmin and dt is unchanged (factor == 1),
# reset dt to dtmin to avoid accumulating float precision errors
dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt)
Owner

I think dt = prev_dt * 1 should imply dt = prev_dt exactly regardless of precision errors / floating point weirdness. Did you find this line necessary?

Author

For my use case (probably most others too), this doesn't really make a difference. I just wanted a failsafe to avoid max steps reached errors.

But thought I'd add this for completeness. Say we are at dtmin for many steps.

dt = dtmin
t1 = t0 + dt
## next step
prev_dt = t1 - t0 = (dt + e), where e is float error
t1 = t0 + (dt + e)
## next step
prev_dt = t1 - t0 = ((dt + e) + e)
t1 = t0 + ((dt + e) + e)
...
and so on

However, I just tested it, and dt seems to just get truncated to the precision of t0 and stays there until t0 jumps to the next exponent. So the error accumulation is probably negligible for most.
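A sketch that reproduces this observation with plain Python floats, under assumed, hand-picked values for t0 and dtmin (the loop is a simplified stand-in for the solver, not the actual test that was run):

```python
# Hand-picked values (assumption): dtmin has bits below the float64
# resolution of t0, so the first sum t0 + dtmin is rounded.
dtmin = 2.0**-20 + 0.75 * 2.0**-52
t0 = 1.0
t1 = t0 + dtmin

seen = []
for _ in range(1000):
    prev_dt = t1 - t0            # step size recovered from the time points
    seen.append(prev_dt)
    t0, t1 = t1, t1 + prev_dt    # factor == 1: reuse prev_dt unchanged

distinct = set(seen)
print(len(distinct), seen[0] - dtmin)
```

In this run the recovered step is rounded once, on the first step, and then repeats unchanged for as long as the time values stay in the same exponent range.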

The lines can probably go unless you really, really care about being as close as possible to dtmin at all times?
