@@ -51,6 +51,7 @@ def _run(y0__args__term, saveat, adjoint):
             ).ys
         )
 
+    # Only does gradients with respect to y0
     def _run_finite_diff(y0__args__term, saveat, adjoint):
         y0, args, term = y0__args__term
         y0_a = y0 + jnp.array([1e-5, 0])
@@ -62,14 +63,40 @@ def _run_finite_diff(y0__args__term, saveat, adjoint):
         out_b = (val_b - val) / 1e-5
         return jnp.stack([out_a, out_b])
 
-    diff, nondiff = eqx.partition(y0__args__term, eqx.is_inexact_array)
-    _run_grad = eqx.filter_jit(
-        jax.grad(
-            lambda d, saveat, adjoint: _run(eqx.combine(d, nondiff), saveat, adjoint)
-        )
-    )
+    inexact, static = eqx.partition(y0__args__term, eqx.is_inexact_array)
+
+    def _run_inexact(inexact, saveat, adjoint):
+        return _run(eqx.combine(inexact, static), saveat, adjoint)
+
+    _run_grad = eqx.filter_jit(jax.grad(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
+    twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
+
+    @eqx.filter_jit
+    def _run_vmap_grad(twice_inexact, saveat, adjoint):
+        f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
+        return f(twice_inexact, saveat, adjoint)
+
+    # @eqx.filter_jit
+    # def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
+    #     @jax.vmap
+    #     def _run_impl(inexact):
+    #         y0__args__term = eqx.combine(inexact, static)
+    #         return _run_finite_diff(y0__args__term, saveat, adjoint)
+    #     return _run_impl(twice_inexact)
+
+    @eqx.filter_jit
+    def _run_grad_vmap(twice_inexact, saveat, adjoint):
+        @jax.grad
+        def _run_impl(twice_inexact):
+            f = jax.vmap(_run_inexact, in_axes=(0, None, None))
+            out = f(twice_inexact, saveat, adjoint)
+            assert out.shape == (2,)
+            return jnp.sum(out)
+
+        return _run_impl(twice_inexact)
+
     # Yep, test that they're not implemented. We can remove these checks if we ever
     # do implement them.
     # Until that day comes, it's worth checking that things don't silently break.
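(Aside, not part of the patch: the new `_run_inexact` helper is the standard Equinox filtering pattern, where `eqx.partition`/`eqx.combine` split a pytree into its differentiable array leaves and its static remainder so that plain `jax.grad`, and later `jax.vmap`, only ever see the array half. A minimal standalone sketch of that pattern, using a hypothetical toy pytree rather than the test's `y0__args__term`:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    # Toy pytree mixing differentiable arrays with static (non-array) data.
    tree = {"y0": jnp.array([0.9, 5.4]), "label": "not differentiable"}

    # Split into inexact-array leaves (to differentiate) and everything else.
    inexact, static = eqx.partition(tree, eqx.is_inexact_array)

    def loss(inexact):
        # Recombine before use, so the function body sees the full pytree.
        full = eqx.combine(inexact, static)
        return jnp.sum(full["y0"] ** 2)

    # `grads` has the same structure as `inexact`; static leaves stay None.
    grads = eqx.filter_jit(jax.grad(loss))(inexact)
)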
@@ -99,11 +126,11 @@ def _convert_float0(x):
             fd_grads = _run_finite_diff(
                 y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint()
             )
-            direct_grads = _run_grad(diff, saveat, diffrax.DirectAdjoint())
+            direct_grads = _run_grad(inexact, saveat, diffrax.DirectAdjoint())
             recursive_grads = _run_grad(
-                diff, saveat, diffrax.RecursiveCheckpointAdjoint()
+                inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
             )
-            backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint())
+            backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
             assert shaped_allclose(fd_grads, direct_grads[0])
             assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
             assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
@@ -120,6 +147,34 @@ def _convert_float0(x):
             direct_grads = jtu.tree_map(_convert_float0, direct_grads)
             recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
             backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
+            assert shaped_allclose(fd_grads, direct_grads[0])
+            assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
+            assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
+
+            fd_grads = jtu.tree_map(lambda *x: jnp.stack(x), fd_grads, fd_grads)
+            direct_grads = _run_vmap_grad(
+                twice_inexact, saveat, diffrax.DirectAdjoint()
+            )
+            recursive_grads = _run_vmap_grad(
+                twice_inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
+            )
+            backsolve_grads = _run_vmap_grad(
+                twice_inexact, saveat, diffrax.BacksolveAdjoint()
+            )
+            assert shaped_allclose(fd_grads, direct_grads[0])
+            assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
+            assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
+
+            direct_grads = _run_grad_vmap(
+                twice_inexact, saveat, diffrax.DirectAdjoint()
+            )
+            recursive_grads = _run_grad_vmap(
+                twice_inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
+            )
+            backsolve_grads = _run_grad_vmap(
+                twice_inexact, saveat, diffrax.BacksolveAdjoint()
+            )
+            assert shaped_allclose(fd_grads, direct_grads[0])
             assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
             assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
 
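(Aside, not part of the patch: the two new helpers exercise both orders of composition. `_run_vmap_grad` is vmap-of-grad, producing one gradient per batch element, while `_run_grad_vmap` differentiates the sum of a vmapped run; because each batch element only contributes to its own summand, the two orderings should agree, which is what the new assertions check against the stacked finite-difference gradients. A minimal sketch of why they coincide, with a hypothetical toy function standing in for the test's `_run_inexact`:

    import jax
    import jax.numpy as jnp

    def f(x):
        # Toy scalar-valued function in place of the solve-and-sum in the test.
        return jnp.sum(jnp.sin(x) ** 2)

    batch = jnp.stack([jnp.array([0.9, 5.4]), jnp.array([0.9, 5.4])])

    # vmap-of-grad: differentiate each batch element independently.
    vmap_grad = jax.vmap(jax.grad(f))(batch)

    # grad-of-vmap: map over the batch, sum the scalar outputs, differentiate once.
    grad_vmap = jax.grad(lambda b: jnp.sum(jax.vmap(f)(b)))(batch)

    # Summing decouples the batch elements, so the per-element gradients match.
    assert jnp.allclose(vmap_grad, grad_vmap)
)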