Commit 29357a7

Added more comprehensive gradient tests
1 parent: 020942d

File tree

1 file changed: +64 -9

test/test_adjoint.py

Lines changed: 64 additions & 9 deletions
@@ -51,6 +51,7 @@ def _run(y0__args__term, saveat, adjoint):
             ).ys
         )
 
+    # Only does gradients with respect to y0
     def _run_finite_diff(y0__args__term, saveat, adjoint):
         y0, args, term = y0__args__term
         y0_a = y0 + jnp.array([1e-5, 0])
@@ -62,14 +63,40 @@ def _run_finite_diff(y0__args__term, saveat, adjoint):
         out_b = (val_b - val) / 1e-5
         return jnp.stack([out_a, out_b])
 
-    diff, nondiff = eqx.partition(y0__args__term, eqx.is_inexact_array)
-    _run_grad = eqx.filter_jit(
-        jax.grad(
-            lambda d, saveat, adjoint: _run(eqx.combine(d, nondiff), saveat, adjoint)
-        )
-    )
+    inexact, static = eqx.partition(y0__args__term, eqx.is_inexact_array)
+
+    def _run_inexact(inexact, saveat, adjoint):
+        return _run(eqx.combine(inexact, static), saveat, adjoint)
+
+    _run_grad = eqx.filter_jit(jax.grad(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
+    twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
+
+    @eqx.filter_jit
+    def _run_vmap_grad(twice_inexact, saveat, adjoint):
+        f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
+        return f(twice_inexact, saveat, adjoint)
+
+    # @eqx.filter_jit
+    # def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
+    #     @jax.vmap
+    #     def _run_impl(inexact):
+    #         y0__args__term = eqx.combine(inexact, static)
+    #         return _run_finite_diff(y0__args__term, saveat, adjoint)
+    #     return _run_impl(twice_inexact)
+
+    @eqx.filter_jit
+    def _run_grad_vmap(twice_inexact, saveat, adjoint):
+        @jax.grad
+        def _run_impl(twice_inexact):
+            f = jax.vmap(_run_inexact, in_axes=(0, None, None))
+            out = f(twice_inexact, saveat, adjoint)
+            assert out.shape == (2,)
+            return jnp.sum(out)
+
+        return _run_impl(twice_inexact)
+
     # Yep, test that they're not implemented. We can remove these checks if we ever
     # do implement them.
     # Until that day comes, it's worth checking that things don't silently break.
@@ -99,11 +126,11 @@ def _convert_float0(x):
     fd_grads = _run_finite_diff(
         y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint()
     )
-    direct_grads = _run_grad(diff, saveat, diffrax.DirectAdjoint())
+    direct_grads = _run_grad(inexact, saveat, diffrax.DirectAdjoint())
     recursive_grads = _run_grad(
-        diff, saveat, diffrax.RecursiveCheckpointAdjoint()
+        inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
     )
-    backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint())
+    backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
     assert shaped_allclose(fd_grads, direct_grads[0])
     assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
     assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
@@ -120,6 +147,34 @@ def _convert_float0(x):
     direct_grads = jtu.tree_map(_convert_float0, direct_grads)
     recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
     backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
+    assert shaped_allclose(fd_grads, direct_grads[0])
+    assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
+    assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
+
+    fd_grads = jtu.tree_map(lambda *x: jnp.stack(x), fd_grads, fd_grads)
+    direct_grads = _run_vmap_grad(
+        twice_inexact, saveat, diffrax.DirectAdjoint()
+    )
+    recursive_grads = _run_vmap_grad(
+        twice_inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
+    )
+    backsolve_grads = _run_vmap_grad(
+        twice_inexact, saveat, diffrax.BacksolveAdjoint()
+    )
+    assert shaped_allclose(fd_grads, direct_grads[0])
+    assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
+    assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
+
+    direct_grads = _run_grad_vmap(
+        twice_inexact, saveat, diffrax.DirectAdjoint()
+    )
+    recursive_grads = _run_grad_vmap(
+        twice_inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
+    )
+    backsolve_grads = _run_grad_vmap(
+        twice_inexact, saveat, diffrax.BacksolveAdjoint()
+    )
+    assert shaped_allclose(fd_grads, direct_grads[0])
     assert shaped_allclose(direct_grads, recursive_grads, atol=1e-5)
     assert shaped_allclose(direct_grads, backsolve_grads, atol=1e-5)
 
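For reference, the two compositions exercised by the new helpers, vmap-of-grad (_run_vmap_grad) and grad-of-vmap (_run_grad_vmap), can be illustrated on a toy function. The following is a minimal standalone sketch assuming only jax and jax.numpy; f and batch are hypothetical stand-ins for _run_inexact and twice_inexact and are not part of the commit.

    import jax
    import jax.numpy as jnp

    # Toy scalar-valued function standing in for _run_inexact (hypothetical).
    def f(y0):
        return jnp.sum(y0 ** 2)

    # Two stacked copies of the same input, analogous to twice_inexact.
    batch = jnp.stack([jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0])])

    # vmap-of-grad: one gradient per batch element (cf. _run_vmap_grad).
    per_example = jax.vmap(jax.grad(f))(batch)

    # grad-of-vmap: differentiate the summed, vmapped output (cf. _run_grad_vmap).
    summed = jax.grad(lambda b: jnp.sum(jax.vmap(f)(b)))(batch)

    # With a sum reduction the two coincide, which is what the test's assertions compare.
    assert jnp.allclose(per_example, summed)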
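The eqx.partition / eqx.combine idiom in the diff (split a pytree into inexact-array leaves and static leaves, then differentiate only the former) can likewise be sketched in isolation. Names below are illustrative and assume only equinox, jax and jax.numpy; they are not taken from the test.

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    # A pytree mixing differentiable arrays with static (non-inexact) data.
    tree = (jnp.array([1.0, 2.0]), {"n_steps": 16, "scale": jnp.array(3.0)})

    # Inexact-array leaves go into `inexact`; everything else goes into `static`.
    inexact, static = eqx.partition(tree, eqx.is_inexact_array)

    def loss(inexact):
        y0, aux = eqx.combine(inexact, static)  # reassemble the full pytree
        return aux["scale"] * jnp.sum(y0 ** 2)

    # Gradients are taken only with respect to the inexact leaves.
    grads = jax.grad(loss)(inexact)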