Fix for PIDController infinite loop at dtmin #704
When dt is clipped to dtmin and we wish to continue solving (force_dtmin=True), whether a step should be kept is calculated from:
prev_dt = t1 - t0 (where, in the previous step, next_t1 = next_t0 + dt)
keep_step = keep_step | (prev_dt <= self.dtmin)
For high t0, float error can leave prev_dt never <= self.dtmin, so further steps are never accepted -> infinite loop.
Fix: add a keep_next_step: bool flag to controller_state, and track when we are, and continue to be, at dtmin.
This solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1). If we do not reset dt, the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors, with the potential to drift away from the desired dtmin until a step warrants a relaxation of the step size (factor>1). These errors are likely to be minor, but I believe this is the intended behaviour.
Thanks! I will take a look as soon as I'm able :)
patrick-kidger
left a comment
One comment aside, this LGTM!
result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result)
# if we are already at dtmin and dt is unchanged (factor == 1),
# reset dt to dtmin to avoid accumulating float precision errors
dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt)
I think dt = prev_dt * 1 should imply dt = prev_dt exactly regardless of precision errors / floating point weirdness. Did you find this line necessary?
For my use case (probably most others too), this doesn't really make a difference. I just wanted a failsafe to avoid max steps reached errors.
But thought I'd add this for completeness. Say we are at dtmin for many steps.
dt = dtmin
t1 = t0 + dt
## next step
prev_dt = t1 - t0 = (dt + e), where e is float error
t1 = t0 + (dt + e)
## next step
prev_dt = t1 - t0 = ((dt + e) + e)
t1 = t0 + ((dt + e) + e)
...
and so on
However, I just tested it, and dt seems to just get truncated to the precision of t0 and stays there until t0 jumps to the next exponent. So the error accumulation is probably negligible for most.
The lines can probably go unless you really, really care about being as close as possible to dtmin at all times?
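The observation that dt settles rather than drifts can be checked with a short float64 experiment. This is only a sketch under assumed values (t0 = 1e6, dtmin = 1e-3 and 1000 steps are hypothetical and unrelated to any real solve): it repeatedly recomputes the step as t1 - t0 and reuses it unchanged, mimicking factor = 1.

```python
dtmin = 1e-3
t0 = 1e6          # stays below 2**20 for the whole run, so float spacing is constant
dt = dtmin
seen = []

for _ in range(1000):
    t1 = t0 + dt
    dt = t1 - t0  # what the controller would recompute as prev_dt
    seen.append(dt)
    t0 = t1

distinct = set(seen)
settled_dt = seen[-1]
print(len(distinct))        # number of different dt values observed
print(settled_dt - dtmin)   # offset from the requested dtmin
```

In this run dt is rounded once on the first step and then stays at that value, which matches the behaviour described above. A further rounding would only be expected once t0 crosses into a range with coarser float spacing.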
A proposed fix for an unintended infinite loop encountered in PIDController, issue: #703
issue
This situation is encountered in pid.PIDController.adapt_step_size() when we set a dtmin and force_dtmin=True.
When the previous step is performed at dt=dtmin, the intention was for keep_step to be set to True.
However, prev_dt is recalculated from prev_dt = t1 - t0, when in the previous step next_t1 = next_t0 + dt.
The intention is that prev_dt equals dt from the last step, but when the exponents of the t0, t1 floats are higher than that of dt, we lose float precision in prev_dt, and there is a good chance of (prev_dt <= self.dtmin) being False.
In this situation, no further values are changed and steps loop and fail infinitely.
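The precision loss can be reproduced with plain Python floats. This is a minimal sketch, not Diffrax code: recomputed_dt is a hypothetical helper, and t0 = 1e6 with dtmin = 1e-3 in float64 are values chosen only for illustration (JAX defaults to float32, where the same effect would be expected at much smaller t0).

```python
def recomputed_dt(t0: float, dt: float) -> float:
    """Take a step of size dt from t0, then recover the step size as t1 - t0."""
    t1 = t0 + dt    # rounded to the nearest float representable near t0
    return t1 - t0  # no longer guaranteed to equal dt


dtmin = 1e-3

# Near t0 = 0 the sum is exact, so the recovered step equals dtmin.
prev_dt_small_t0 = recomputed_dt(0.0, dtmin)

# At t0 = 1e6 adjacent floats are about 1.2e-10 apart, so t0 + dtmin is
# rounded and the recovered step differs slightly from dtmin.
prev_dt_large_t0 = recomputed_dt(1e6, dtmin)

print(prev_dt_small_t0 <= dtmin)   # the check the controller relies on
print(prev_dt_large_t0 <= dtmin)
print(prev_dt_large_t0 - dtmin)    # size of the rounding offset
```

With these particular values the second check fails because the sum happens to round upward. Other combinations of t0 and dtmin round downward and pass the check, so the loop only appears for some inputs.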
solution
Instead of (prev_dt <= self.dtmin) we add a bool flag at_dtmin which we pass via the controller_state.
We set this to True when dt <= self.dtmin. Importantly, we also want to keep it as True if we are at dtmin already, and we want to avoid an accumulation of precision errors if we stay at dtmin for a long time. So, in this case, we keep resetting dt to the desired dtmin.
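The flag-based logic can be sketched in plain Python. This is an illustrative reconstruction, not the actual diff: State, adapt and error_ok are hypothetical names, the real controller works on JAX arrays with jnp.where, and the exact rule for carrying the flag forward is an assumption based on the description above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class State:
    at_dtmin: bool  # hypothetical stand-in for the flag stored in controller_state


def adapt(t0, t1, error_ok, factor, dtmin, state):
    """Return (keep_step, next_dt, next_state) for one step attempt."""
    prev_dt = t1 - t0
    # Old check: error_ok or (prev_dt <= dtmin), which rounding can defeat.
    keep_step = error_ok or state.at_dtmin
    dt = prev_dt * factor
    # Stay flagged if the new dt is at or below dtmin, or if we were already
    # at dtmin and the step size is unchanged (assumed carry-over rule).
    at_dtmin = (dt <= dtmin) or (state.at_dtmin and factor == 1)
    # Clip to dtmin (force_dtmin=True behaviour) or reset to it when flagged.
    dt = dtmin if at_dtmin else dt
    return keep_step, dt, State(at_dtmin)


dtmin, t0 = 1e-3, 1e6
t1 = t0 + dtmin                # previous step was taken at dtmin
old_keep = (t1 - t0) <= dtmin  # the comparison the flag replaces

keep, dt, state = adapt(t0, t1, error_ok=False, factor=0.5,
                        dtmin=dtmin, state=State(at_dtmin=True))
print(old_keep, keep, dt == dtmin, state.at_dtmin)
```

Because the decision to keep the step reads the stored flag rather than a recomputed difference, it does not depend on how t0 + dt happened to round.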