|
10 | 10 | import jax.tree_util as jtu |
11 | 11 | from jax.typing import ArrayLike |
12 | 12 |
|
13 | | -from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint |
| 13 | +from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint |
14 | 14 | from .custom_types import Array, Bool, Int, PyTree, Scalar |
15 | 15 | from .event import AbstractDiscreteTerminatingEvent |
16 | 16 | from .global_interpolation import DenseInterpolation |
@@ -415,6 +415,11 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: |
415 | 415 | ) |
416 | 416 | new_state = eqx.tree_at(lambda s: s.result, new_state, result) |
417 | 417 |
|
| 418 | + if not _filtering: |
| 419 | + # This is only necessary for Equinox <0.11.1. |
| 420 | + # After that, this fix has been upstreamed to Equinox. |
| 421 | + # TODO: remove once we make Equinox >=0.11.1 required. |
| 422 | + new_state = jtu.tree_map(jnp.array, new_state) |
418 | 423 | return new_state |
419 | 424 |
|
420 | 425 | _filtering = True |
@@ -633,22 +638,6 @@ def diffeqsolve( |
633 | 638 | "An SDE should not be solved with adaptive step sizes with Euler's " |
634 | 639 | "method, as it may not converge to the correct solution." |
635 | 640 | ) |
636 | | - # TODO: remove these lines. |
637 | | - # |
638 | | - # These are to work around an edge case: on the backward pass, |
639 | | - # RecursiveCheckpointAdjoint currently tries to differentiate the overall |
640 | | - # per-step function wrt all floating-point arrays. In particular this includes |
641 | | - # `state.tprev`, which feeds into the control, which feeds into |
642 | | - # VirtualBrownianTree, which can't be differentiated. |
643 | | - # We're waiting on JAX to offer a way of specifying which arguments to a |
644 | | - # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more |
645 | | - # precisely determine what to differentiate wrt. |
646 | | - # |
647 | | - # We don't replace this in the case of an unsafe SDE because |
648 | | - # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we |
649 | | - # should let the normal error be raised. |
650 | | - if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms): |
651 | | - adjoint = DirectAdjoint() |
652 | 641 | if is_unsafe_sde(terms): |
653 | 642 | if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): |
654 | 643 | raise ValueError( |
|
0 commit comments