Skip to content

Commit fe1ca9a

Browse files
authored
Merge branch 'patrick-kidger:main' into main
2 parents 539757d + 712c208 commit fe1ca9a

File tree

16 files changed

+125
-48
lines changed

16 files changed

+125
-48
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,24 @@ If you found this library useful in academic research, please cite: [(arXiv link
6161

6262
## See also: other libraries in the JAX ecosystem
6363

64+
[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
65+
6466
[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
6567

6668
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
6769

68-
[Lineax](https://github.com/google/lineax): linear solvers and linear least squares.
70+
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
6971

70-
[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
72+
[Lineax](https://github.com/google/lineax): linear solvers.
7173

72-
[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
74+
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
75+
76+
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
7377

7478
[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
7579

80+
[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
81+
7682
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
83+
84+
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

diffrax/adjoint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def _loop_backsolve_bwd(
530530
throw,
531531
init_state,
532532
):
533+
assert discrete_terminating_event is None
533534

534535
#
535536
# Unpack our various arguments. Delete a lot of things just to make sure we're not
@@ -565,7 +566,6 @@ def _loop_backsolve_bwd(
565566
adjoint=self,
566567
solver=solver,
567568
stepsize_controller=stepsize_controller,
568-
discrete_terminating_event=discrete_terminating_event,
569569
terms=adjoint_terms,
570570
dt0=None if dt0 is None else -dt0,
571571
max_steps=max_steps,
@@ -744,6 +744,7 @@ def loop(
744744
init_state,
745745
passed_solver_state,
746746
passed_controller_state,
747+
discrete_terminating_event,
747748
**kwargs,
748749
):
749750
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -785,6 +786,10 @@ def loop(
785786
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
786787
"a single term."
787788
)
789+
if discrete_terminating_event is not None:
790+
raise NotImplementedError(
791+
"`diffrax.BacksolveAdjoint` is not compatible with events."
792+
)
788793

789794
y = init_state.y
790795
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -798,6 +803,7 @@ def loop(
798803
saveat=saveat,
799804
init_state=init_state,
800805
solver=solver,
806+
discrete_terminating_event=discrete_terminating_event,
801807
**kwargs,
802808
)
803809
final_state = _only_transpose_ys(final_state)

diffrax/brownian/path.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class UnsafeBrownianPath(AbstractBrownianPath):
3232
correlation structure isn't needed.)
3333
"""
3434

35-
shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
35+
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
3636
# Handled as a string because PRNGKey is actually a function, not a class, which
3737
# makes it appear badly in autogenerated documentation.
3838
key: "jax.random.PRNGKey" # noqa: F821

diffrax/brownian/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class VirtualBrownianTree(AbstractBrownianPath):
6060
t0: Scalar = field(init=True)
6161
t1: Scalar = field(init=True) # override init=False in AbstractPath
6262
tol: Scalar
63-
shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
63+
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
6464
key: "jax.random.PRNGKey" # noqa: F821
6565

6666
def __init__(

diffrax/global_interpolation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class LinearInterpolation(AbstractGlobalInterpolation):
6767
ys: PyTree[Array["times", ...]] # noqa: F821
6868

6969
def __post_init__(self):
70+
super().__post_init__()
71+
7072
def _check(_ys):
7173
if _ys.shape[0] != self.ts.shape[0]:
7274
raise ValueError(
@@ -179,6 +181,8 @@ class CubicInterpolation(AbstractGlobalInterpolation):
179181
]
180182

181183
def __post_init__(self):
184+
super().__post_init__()
185+
182186
def _check(d, c, b, a):
183187
error_msg = (
184188
"Each cubic coefficient must have `times - 1` entries, where "
@@ -287,12 +291,14 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree:
287291
class DenseInterpolation(AbstractGlobalInterpolation):
288292
ts_size: Int # Takes values in {1, 2, 3, ...}
289293
infos: DenseInfos
290-
interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field()
294+
interpolation_cls: Type[AbstractLocalInterpolation] = eqx.field(static=True)
291295
direction: Scalar
292296
t0_if_trivial: Array
293297
y0_if_trivial: PyTree[Array]
294298

295299
def __post_init__(self):
300+
super().__post_init__()
301+
296302
def _check(_infos):
297303
assert _infos.shape[0] + 1 == self.ts.shape[0]
298304

diffrax/integrate.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import jax.tree_util as jtu
1111
from jax.typing import ArrayLike
1212

13-
from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
13+
from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
1414
from .custom_types import Array, Bool, Int, PyTree, Scalar
1515
from .event import AbstractDiscreteTerminatingEvent
1616
from .global_interpolation import DenseInterpolation
@@ -415,6 +415,11 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
415415
)
416416
new_state = eqx.tree_at(lambda s: s.result, new_state, result)
417417

418+
if not _filtering:
419+
# This is only necessary for Equinox <0.11.1.
420+
# After that, this fix has been upstreamed to Equinox.
421+
# TODO: remove once we make Equinox >=0.11.1 required.
422+
new_state = jtu.tree_map(jnp.array, new_state)
418423
return new_state
419424

420425
_filtering = True
@@ -633,22 +638,6 @@ def diffeqsolve(
633638
"An SDE should not be solved with adaptive step sizes with Euler's "
634639
"method, as it may not converge to the correct solution."
635640
)
636-
# TODO: remove these lines.
637-
#
638-
# These are to work around an edge case: on the backward pass,
639-
# RecursiveCheckpointAdjoint currently tries to differentiate the overall
640-
# per-step function wrt all floating-point arrays. In particular this includes
641-
# `state.tprev`, which feeds into the control, which feeds into
642-
# VirtualBrownianTree, which can't be differentiated.
643-
# We're waiting on JAX to offer a way of specifying which arguments to a
644-
# custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more
645-
# precisely determine what to differentiate wrt.
646-
#
647-
# We don't replace this in the case of an unsafe SDE because
648-
# RecursiveCheckpointAdjoint will raise an error in that case anyway, so we
649-
# should let the normal error be raised.
650-
if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms):
651-
adjoint = DirectAdjoint()
652641
if is_unsafe_sde(terms):
653642
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
654643
raise ValueError(

diffrax/nonlinear_solver/newton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def body_fn(val):
141141
val = (flat, step + 1, diffsize, diffsize_prev)
142142
return val
143143

144-
val = (flat, 0, 0.0, 0.0)
144+
val = (flat, 0, jnp.array(0.0), jnp.array(0.0))
145145
val = lax.while_loop(cond_fn, body_fn, val)
146146
flat, num_steps, diffsize, diffsize_prev = val
147147

diffrax/solution.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,16 @@
1313
class RESULTS(metaclass=eqxi.ContainerMeta):
1414
successful = ""
1515
discrete_terminating_event_occurred = (
16-
"Terminating solve because a discrete event occurred."
16+
"Terminating differential equation solve because a discrete terminating event "
17+
"occurred."
1718
)
1819
max_steps_reached = (
19-
"The maximum number of solver steps was reached. Try increasing `max_steps`."
20+
"The maximum number of steps was reached in the differential equation solver. "
21+
"Try increasing `diffrax.diffeqsolve(..., max_steps=...)`."
22+
)
23+
dt_min_reached = (
24+
"The minimum step size was reached in the differential equation solver."
2025
)
21-
dt_min_reached = "The minimum step size was reached."
2226
implicit_divergence = "Implicit method diverged."
2327
implicit_nonconvergence = (
2428
"Implicit method did not converge within the required number of iterations."

diffrax/step_size_controller/adaptive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,8 @@ def adapt_step_size(
423423
# h_n is the nth step size
424424
# ε_n = atol + norm(y) * rtol with y on the nth step
425425
# r_n = norm(y_error) with y_error on the nth step
426-
# δ_{n,m} = norm(y_error / (atol + norm(y) * rtol)) with y_error on the nth
427-
# step and y on the mth step
426+
# δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
427+
# step and y on the mth step
428428
# β_1 = pcoeff + icoeff + dcoeff
429429
# β_2 = -(pcoeff + 2 * dcoeff)
430430
# β_3 = dcoeff

diffrax/term.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,9 @@ def is_vf_expensive(
409409
y: Tuple[PyTree, PyTree, PyTree, PyTree],
410410
args: PyTree,
411411
) -> bool:
412-
return self.term.is_vf_expensive(t0, t1, y, args)
412+
_t0 = jnp.where(self.direction == 1, t0, -t1)
413+
_t1 = jnp.where(self.direction == 1, t1, -t0)
414+
return self.term.is_vf_expensive(_t0, _t1, y, args)
413415

414416

415417
class AdjointTerm(AbstractTerm):
@@ -422,8 +424,8 @@ def is_vf_expensive(
422424
y: Tuple[PyTree, PyTree, PyTree, PyTree],
423425
args: PyTree,
424426
) -> bool:
425-
control = self.contr(t0, t1)
426-
if sum(c.size for c in jtu.tree_leaves(control)) in (0, 1):
427+
control_struct = jax.eval_shape(self.contr, t0, t1)
428+
if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
427429
return False
428430
else:
429431
return True

0 commit comments

Comments
 (0)