Skip to content

Conversation

@sammccallum
Copy link

Re-opening #593.

Implements AbstractReversibleSolver base class and ReversibleAdjoint for reversible back propagation.

This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.

Implementation

AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:

This method should reconstruct y0, solver_state at t0 from y1, solver_state at t1. See the aforementioned solvers for examples.

When backpropagating, ReversibleAdjoint uses this backward_step to reconstruct state. We then take a vjp through a local forward step and accumulate gradients.

ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!

We allow arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull back these gradients onto y0, args, terms using the solver.init method.

@sammccallum
Copy link
Author

I've also added the Reversible RK solvers here which just subclasses AbstractReversibleAdjoint. Let me know what you think of this and I can add some documentation when it's good to go!

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, gosh, this one took far too long for me to get around. Thank you for your patience! If I can I'd like this to be the next big thing I focus on getting in to Diffrax.

reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A very minor bug here: if it just so happens that we run with t0 == t1 then we'll end up with reversible_ts = [t0 inf inf inf ...], which will not produce desired results in the backward solve.

We have a special branch to handle the saving in the t0 == t1 case, we should add a line handling the state.reversible_ts is not None case there.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 1409 to 1464
# Reversible info
if max_steps is None:
reversible_ts = None
reversible_save_index = None
else:
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
reversible_save_index = 0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've thought of an alternative for this extra buffer, btw: ReversibleAdjoint.loop could intercept saveat and add an SubSaveAt(steps=True, save_fn=lambda t, y, args: None) to record the extra times. Then peel it off again when returning the final state.

I think that (a) might be doable without making any changes to _integrate.py and (b) would allow for also supporting SaveAt(steps=True). (As in that case we can just skip adding the extra SubSaveAt.) And (c) would avoid a few of the subtle issues I've commented on above about exactly which tprev/tnext-like value is actually being saved, because you can trust in the rest of the existing diffeqsolve to do that for you.

It's not a strong suggestion though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.

@sammccallum sammccallum force-pushed the AbstractReversibleSolver branch from 0cfd4ec to 3a26ac3 Compare May 14, 2025 10:14
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular now that we're settled on just AbstractERK then I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁

@patrick-kidger patrick-kidger changed the base branch from main to dev June 16, 2025 22:39
@patrick-kidger
Copy link
Owner

patrick-kidger commented Jun 16, 2025

Heads-up that I've just updated the base branch to dev. It looks like there are a number of old commits sitting around on this PR, likely from where this branch forked off of main. You should be able to resolve these by first (a) squashing all the commits that actually belong on this branch together, and then (b) rebasing that new single commit on top of dev.

(Unrelatedly, lmk when this branch is ready for review.)

add reversible

testing

testing

AbstractReversibleSolver + ReversibleAdjoint

allow arbitrary interpolation

unpacking over indexing

jax while loop

collapse saveat ValueErrors

remove statonovich solver condition

remove unused returns from AbstractReversibleSolver backward_step

add test and remove messy benchmark

add wrapped solver + tests

made_jump=True for both solver steps

improve docstrings

AbstractSolver and docstring note about SDEs

add AbstractReversibleSolver to public API

newline in docstrings

return RESULTS from reversible backward_step

restrict Reversible to AbstractERK and check result in adjoint

correct tprev and tnext of solver init

switch to linear interpolation and y0,y1 dense_info

name UReversible

various doc formatting changes

AbstractReversibleSolver check

add disable_fsal property to AbstractRungeKutta and use in UReversible

allow t0 != 0

Handle StepTo controller

t0==t1 branch
@sammccallum sammccallum force-pushed the AbstractReversibleSolver branch from 8d59058 to ae01942 Compare October 12, 2025 15:35
@sammccallum
Copy link
Author

I think I've now addressed all of your comments, so it should be ready for review 👍

Understanding how to rebase through multiple merges was an experience but I believe that is correct now...

final_state,
(reversible_ts, reversible_save_index),
is_leaf=_is_none,
)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left this t0==t1 branch separate from the above jax.lax.cond for readability but we can obviously combine if this is okay.

@sammccallum
Copy link
Author

(No pressure to review anytime soon Patrick, I just marked this as "review requested" so it's easy for you to see across your sprawling jax empire)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants