Skip to content

Commit 539757d

Browse files
authored
Merge branch 'patrick-kidger:main' into main
2 parents 3bfa34e + 7f30854 commit 539757d

File tree

5 files changed

+49
-4
lines changed

5 files changed

+49
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
2121
pip install diffrax
2222
```
2323

24-
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.10+.
24+
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.11+.
2525

2626
## Documentation
2727

diffrax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,4 @@
9393
)
9494

9595

96-
__version__ = "0.4.0"
96+
__version__ = "0.4.1"

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ _From a technical point of view, the internal structure of the library is pretty
2020
pip install diffrax
2121
```
2222

23-
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.10+.
23+
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.11+.
2424

2525
## Quick example
2626

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
python_requires = "~=3.9"
4848

49-
install_requires = ["jax>=0.4.13", "equinox>=0.10.10"]
49+
install_requires = ["jax>=0.4.13", "equinox>=0.10.11"]
5050

5151
setuptools.setup(
5252
name=name,

test/test_adaptive_stepsize_controller.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import diffrax
22
import equinox as eqx
3+
import jax
34
import jax.numpy as jnp
45
import jax.tree_util as jtu
56

7+
from .helpers import shaped_allclose
8+
69

710
def test_step_ts():
811
term = diffrax.ODETerm(lambda t, y, args: -0.2 * y)
@@ -90,3 +93,45 @@ def run(ys, controller, state):
9093
ys = (y0, y1_candidate, y_error)
9194
grads = run(ys, stepsize_controller, state)
9295
assert not any(jnp.isnan(grad).any() for grad in grads)
96+
97+
98+
def test_grad_of_discontinuous_forcing():
99+
def vector_field(t, y, forcing):
100+
y, _ = y
101+
dy = -y + forcing(t)
102+
dsum = y
103+
return dy, dsum
104+
105+
def run(t):
106+
term = diffrax.ODETerm(vector_field)
107+
solver = diffrax.Tsit5()
108+
t0 = 0
109+
t1 = 1
110+
dt0 = None
111+
y0 = 1.0
112+
stepsize_controller = diffrax.PIDController(
113+
rtol=1e-8, atol=1e-8, step_ts=t[None]
114+
)
115+
116+
def forcing(s):
117+
return jnp.where(s < t, 0, 1)
118+
119+
sol = diffrax.diffeqsolve(
120+
term,
121+
solver,
122+
t0,
123+
t1,
124+
dt0,
125+
(y0, 0),
126+
args=forcing,
127+
stepsize_controller=stepsize_controller,
128+
)
129+
_, sum = sol.ys
130+
(sum,) = sum
131+
return sum
132+
133+
r = jax.jit(run)
134+
eps = 1e-5
135+
finite_diff = (r(0.5) - r(0.5 - eps)) / eps
136+
autodiff = jax.jit(jax.grad(run))(0.5)
137+
assert shaped_allclose(finite_diff, autodiff)

0 commit comments

Comments
 (0)