22from typing import Any , Dict
33
44import equinox as eqx
5+ import equinox .internal as eqxi
56import jax .lax as lax
67import jax .numpy as jnp
78import jax .tree_util as jtu
9+ from equinox .internal import ω
810
9- from .misc import implicit_jvp , nondifferentiable_output , ω
11+ from .misc import implicit_jvp
1012from .saveat import SaveAt
1113from .term import AbstractTerm , AdjointTerm
1214
1315
16+ def _is_none (x ):
17+ return x is None
18+
19+
def _no_transpose_final_state(final_state):
    """Block backpropagation through every part of `final_state` except `ys`.

    Each guarded field is wrapped with `eqxi.nondifferentiable_backward` under
    its own `name` (presumably so that an error raised during backpropagation
    can say which field was involved -- TODO confirm against Equinox). The
    whole state is then wrapped once more, without a name, to cover anything
    not listed individually. Finally the individually wrapped fields are
    written back, together with the original, unwrapped `ys`, so that `ys` is
    the only field left untouched by the wrapping.

    **Arguments:**

    - `final_state`: the state returned by the solve loop. Must expose the
        attributes `y`, `tprev`, `tnext`, `solver_state`, `controller_state`,
        `ts`, `ys`, `dense_ts` and `dense_infos`.

    **Returns:**

    A state with the same structure as `final_state`.
    """
    block = eqxi.nondifferentiable_backward

    # Single source of truth for the fields that are touched, in the order in
    # which they are selected and replaced below. `ys` is listed here but is
    # deliberately never passed through `block`.
    field_order = (
        "y",
        "tprev",
        "tnext",
        "solver_state",
        "controller_state",
        "ts",
        "ys",
        "dense_ts",
        "dense_infos",
    )

    # Wrap the individual fields first, reading them from the *unwrapped* state.
    replacements = {
        field: final_state.ys
        if field == "ys"
        else block(getattr(final_state, field), name=field)
        for field in field_order
    }

    blocked_state = block(final_state)  # no more specific name

    # `is_leaf=_is_none` lets fields that are `None` be selected and replaced.
    return eqx.tree_at(
        lambda s: tuple(getattr(s, field) for field in field_order),
        blocked_state,
        tuple(replacements[field] for field in field_order),
        is_leaf=_is_none,
    )
64+
65+
1466class AbstractAdjoint (eqx .Module ):
1567 """Abstract base class for all adjoint methods."""
1668
@@ -30,6 +82,8 @@ def loop(
3082 max_steps ,
3183 throw ,
3284 init_state ,
85+ passed_solver_state ,
86+ passed_controller_state ,
3387 ):
3488 """Runs the main solve loop. Subclasses can override this to provide custom
3589 backpropagation behaviour; see for example the implementation of
@@ -69,27 +123,26 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
69123 For most problems this is the preferred technique for backpropagating through a
70124 differential equation.
71125
72- A binomial checkpointing scheme is used so that memory usage is low.
126+ In addition a binomial checkpointing scheme is used so that memory usage is low.
127+ (This checkpointing can increase compile time a bit, though.)
73128 """
74129
def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
    """Run the solve loop for `RecursiveCheckpointAdjoint`.

    This is a method of `RecursiveCheckpointAdjoint`; it forwards to
    `self._loop_fn` with `is_bounded=True`.

    **Arguments:**

    - `throw`, `passed_solver_state`, `passed_controller_state`: accepted
        because they are part of the common `loop` signature, but not used by
        this adjoint; they are discarded immediately.
    - `**kwargs`: every remaining keyword argument, forwarded unchanged to
        `self._loop_fn`.

    **Returns:**

    Whatever `self._loop_fn(**kwargs, is_bounded=True)` returns.
    """
    del throw, passed_solver_state, passed_controller_state
    return self._loop_fn(**kwargs, is_bounded=True)
78133
79134
class NoAdjoint(AbstractAdjoint):
    """Disable backpropagation through [`diffrax.diffeqsolve`][].

    Forward-mode autodifferentiation (`jax.jvp`) will continue to work as normal.

    If you do not need to differentiate the results of [`diffrax.diffeqsolve`][] then
    this may sometimes improve the speed at which the differential equation is solved.
    """

    def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
        """Run the solve loop and mark its result as not backpropagable.

        **Arguments:**

        - `throw`, `passed_solver_state`, `passed_controller_state`: accepted
            because they are part of the common `loop` signature, but not used
            by this adjoint; they are discarded immediately.
        - `**kwargs`: every remaining keyword argument, forwarded unchanged to
            `self._loop_fn`.

        **Returns:**

        A 2-tuple `(final_state, aux_stats)` as produced by
        `self._loop_fn(**kwargs, is_bounded=False)`, with `final_state` passed
        through `eqxi.nondifferentiable_backward`.
        """
        del throw, passed_solver_state, passed_controller_state
        final_state, aux_stats = self._loop_fn(**kwargs, is_bounded=False)
        # The entire state is wrapped (not just selected fields): this adjoint
        # disables backpropagation altogether, per the class docstring.
        final_state = eqxi.nondifferentiable_backward(final_state)
        return final_state, aux_stats
94147
95148
@@ -135,7 +188,19 @@ class ImplicitAdjoint(AbstractAdjoint):
135188 via the implicit function theorem.
136189 """ # noqa: E501
137190
138- def loop (self , * , args , terms , solver , saveat , throw , init_state , ** kwargs ):
191+ def loop (
192+ self ,
193+ * ,
194+ args ,
195+ terms ,
196+ solver ,
197+ saveat ,
198+ throw ,
199+ init_state ,
200+ passed_solver_state ,
201+ passed_controller_state ,
202+ ** kwargs ,
203+ ):
139204 del throw
140205
141206 # `is` check because this may return a Tracer from SaveAt(ts=<array>)
@@ -144,21 +209,30 @@ def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
144209 "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`."
145210 )
146211
147- init_state = eqx .tree_at (
148- lambda s : (s .y , s .solver_state , s .controller_state ),
149- init_state ,
150- replace_fn = lax .stop_gradient ,
151- )
212+ if not passed_solver_state :
213+ init_state = eqx .tree_at (
214+ lambda s : s .solver_state ,
215+ init_state ,
216+ replace_fn = lax .stop_gradient ,
217+ is_leaf = _is_none ,
218+ )
219+ if not passed_controller_state :
220+ init_state = eqx .tree_at (
221+ lambda s : s .controller_state ,
222+ init_state ,
223+ replace_fn = lax .stop_gradient ,
224+ is_leaf = _is_none ,
225+ )
226+
152227 closure = (self , kwargs , solver , saveat , init_state )
153228 ys , residual = implicit_jvp (_solve , _vf , (args , terms ), closure )
154229
155230 final_state_no_ys , aux_stats = residual
156- return (
157- eqx .tree_at (
158- lambda s : s .ys , final_state_no_ys , ys , is_leaf = lambda x : x is None
159- ),
160- aux_stats ,
231+ final_state = eqx .tree_at (
232+ lambda s : s .ys , final_state_no_ys , ys , is_leaf = _is_none
161233 )
234+ final_state = _no_transpose_final_state (final_state )
235+ return final_state , aux_stats
162236
163237
164238# Compute derivatives with respect to the first argument:
@@ -174,7 +248,7 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
174248 )
175249 del y
176250 return self ._loop_fn (
177- args = args , terms = terms , init_state = init_state , ** kwargs , is_bounded = False
251+ args = args , terms = terms , init_state = init_state , is_bounded = False , ** kwargs
178252 )
179253
180254
@@ -398,7 +472,18 @@ def __init__(self, **kwargs):
398472 )
399473 self .kwargs = kwargs
400474
401- def loop (self , * , args , terms , saveat , init_state , ** kwargs ):
475+ def loop (
476+ self ,
477+ * ,
478+ args ,
479+ terms ,
480+ saveat ,
481+ init_state ,
482+ passed_solver_state ,
483+ passed_controller_state ,
484+ ** kwargs ,
485+ ):
486+ del passed_solver_state , passed_controller_state
402487 if saveat .steps or saveat .dense :
403488 raise NotImplementedError (
404489 "Cannot use `adjoint=BacksolveAdjoint()` with "
@@ -414,13 +499,5 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
414499 final_state , aux_stats = _loop_backsolve (
415500 (y , args , terms ), self = self , saveat = saveat , init_state = init_state , ** kwargs
416501 )
417-
418- # We only allow backpropagation through `ys`; in particular not through
419- # `solver_state` etc.
420- ys = final_state .ys
421- final_state = jtu .tree_map (nondifferentiable_output , final_state )
422- final_state = eqx .tree_at (
423- lambda s : jtu .tree_leaves (s .ys ), final_state , jtu .tree_leaves (ys )
424- )
425-
502+ final_state = _no_transpose_final_state (final_state )
426503 return final_state , aux_stats
0 commit comments