AbstractReversibleSolver + ReversibleAdjoint #603
base: dev
Conversation
I've also added the …
patrick-kidger left a comment
Okay, gosh, this one took me far too long to get around to. Thank you for your patience! If I can, I'd like this to be the next big thing I focus on getting into Diffrax.
```python
    reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)
```
A very minor bug here: if it just so happens that we run with `t0 == t1` then we'll end up with `reversible_ts = [t0 inf inf inf ...]`, which will not produce the desired results in the backward solve.

We have a special branch to handle the saving in the `t0 == t1` case; we should add a line there handling the `state.reversible_ts is not None` case.
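A minimal, hypothetical sketch of the fix being suggested, assuming the names from this PR's buffer initialisation (`reversible_ts`, `reversible_save_index`) and a plain Python branch standing in for the real special-case code:

```python
import jax.numpy as jnp

# Hypothetical sketch: the buffer is pre-filled with inf, so in the
# t0 == t1 special branch we must still write the single time t0,
# otherwise the backward solve sees [inf, inf, ...] instead of
# [t0, inf, ...].  max_steps and t0 are illustrative values.
max_steps = 8
t0 = jnp.array(0.5)
reversible_ts = jnp.full(max_steps + 1, jnp.inf)
reversible_save_index = 0

if reversible_ts is not None:  # mirrors the `state.reversible_ts is not None` check
    reversible_ts = reversible_ts.at[reversible_save_index].set(t0)
```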
See #603 (comment).
```python
# Reversible info
if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0
```
I've thought of an alternative for this extra buffer, btw: `ReversibleAdjoint.loop` could intercept `saveat` and add a `SubSaveAt(steps=True, save_fn=lambda t, y, args: None)` to record the extra times, then peel it off again when returning the final state.

I think that (a) might be doable without making any changes to `_integrate.py`, (b) would allow for also supporting `SaveAt(steps=True)` (as in that case we can just skip adding the extra `SubSaveAt`), and (c) would avoid a few of the subtle issues I've commented on above about exactly which `tprev`/`tnext`-like value is actually being saved, because you can trust the rest of the existing `diffeqsolve` machinery to do that for you.
It's not a strong suggestion though.
Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)
Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.
Force-pushed 0cfd4ec to 3a26ac3
patrick-kidger left a comment
Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular now that we're settled on just AbstractERK then I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁
Heads-up that I've just updated the base branch to …

(Unrelatedly, lmk when this branch is ready for review.)
Commits:
- add reversible testing
- testing
- AbstractReversibleSolver + ReversibleAdjoint
- allow arbitrary interpolation
- unpacking over indexing jax while loop
- collapse saveat ValueErrors
- remove statonovich solver condition
- remove unused returns from AbstractReversibleSolver backward_step
- add test and remove messy benchmark
- add wrapped solver + tests
- made_jump=True for both solver steps
- improve docstrings
- AbstractSolver and docstring note about SDEs
- add AbstractReversibleSolver to public API
- newline in docstrings
- return RESULTS from reversible backward_step
- restrict Reversible to AbstractERK and check result in adjoint
- correct tprev and tnext of solver init
- switch to linear interpolation and y0,y1 dense_info
- name UReversible
- various doc formatting changes
- AbstractReversibleSolver check
- add disable_fsal property to AbstractRungeKutta and use in UReversible
- allow t0 != 0
- Handle StepTo controller
- t0==t1 branch
Force-pushed 8d59058 to ae01942
I think I've now addressed all of your comments, so it should be ready for review 👍 Understanding how to rebase through multiple merges was an experience, but I believe that is correct now...
```python
    final_state,
    (reversible_ts, reversible_save_index),
    is_leaf=_is_none,
)
```
I've left this t0==t1 branch separate from the above jax.lax.cond for readability but we can obviously combine if this is okay.
(No pressure to review anytime soon Patrick, I just marked this as "review requested" so it's easy for you to see across your sprawling jax empire)
Re-opening #593.
Implements the `AbstractReversibleSolver` base class and `ReversibleAdjoint` for reversible backpropagation. This updates `SemiImplicitEuler`, `LeapfrogMidpoint` and `ReversibleHeun` to subclass `AbstractReversibleSolver`.

Implementation

`AbstractReversibleSolver` subclasses `AbstractSolver` and adds a `backward_step` method. This method should reconstruct `y0`, `solver_state` at `t0` from `y1`, `solver_state` at `t1`. See the aforementioned solvers for examples.

When backpropagating, `ReversibleAdjoint` uses this `backward_step` to reconstruct state. We then take a `vjp` through a local forward step and accumulate gradients. `ReversibleAdjoint` now also pulls back gradients from any interpolated values, so we can use `SaveAt(ts=...)`!

We allow arbitrary `solver_state` (provided it can be reconstructed reversibly) and calculate gradients w.r.t. `solver_state`. Finally, we pull back these gradients onto `y0`, `args`, `terms` using the `solver.init` method.
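The backward_step-plus-local-vjp pattern described above can be sketched with a toy leapfrog-midpoint scheme in plain JAX. This is a minimal stand-in for the actual Diffrax solvers: the vector field, step size, bootstrap, and Python loops are all illustrative assumptions, not code from this PR.

```python
import jax
import jax.numpy as jnp

def vf(y):
    # Toy linear vector field dy/dt = -y (an illustrative assumption).
    return -y

def forward_step(state, h):
    # Leapfrog midpoint: y_{n+1} = y_{n-1} + 2*h*f(y_n).
    # The "solver state" is the pair (y_{n-1}, y_n).
    y_prev, y_curr = state
    y_next = y_prev + 2 * h * vf(y_curr)
    return (y_curr, y_next)

def backward_step(state, h):
    # Exact algebraic inverse of forward_step: y_{n-1} = y_{n+1} - 2*h*f(y_n).
    y_curr, y_next = state
    y_prev = y_next - 2 * h * vf(y_curr)
    return (y_prev, y_curr)

def adjoint_step(state, grad_state, h):
    # ReversibleAdjoint-style step: reconstruct the earlier state (no
    # checkpoints stored), then vjp through a recomputed local forward step.
    prev_state = backward_step(state, h)
    _, pullback = jax.vjp(lambda s: forward_step(s, h), prev_state)
    (grad_prev,) = pullback(grad_state)
    return prev_state, grad_prev

h = 0.01
y0 = jnp.array(1.0, dtype=jnp.float32)
state0 = (y0, y0 + h * vf(y0))  # crude Euler bootstrap of the two-point state
state = state0
for _ in range(50):
    state = forward_step(state, h)

# Backward pass: seed with d(y_final)/d(state) and walk back to t0.
grad = (jnp.zeros((), jnp.float32), jnp.ones((), jnp.float32))
for _ in range(50):
    state, grad = adjoint_step(state, grad, h)
# `state` now reconstructs state0, and `grad` holds d(y_final)/d(state0),
# all in O(1) memory rather than storing the trajectory.
```

The same shape generalises to arbitrary `solver_state`: as long as `backward_step` is an exact algebraic inverse, the vjp of one recomputed forward step gives the local cotangents to accumulate.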