Skip to content

Commit 020942d

Browse files
Optimise StepTo to match performance of a naive scan.
1 parent 6af45a5 commit 020942d

File tree

5 files changed

+206
-42
lines changed

5 files changed

+206
-42
lines changed

benchmarks/against_scan.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# This benchmark should be ran on the GPU.
2+
3+
import timeit
4+
5+
import diffrax as dfx
6+
import jax
7+
import jax.lax as lax
8+
import jax.numpy as jnp
9+
10+
11+
# SETUP
12+
13+
N = 256
14+
N_steps = 2000
15+
ts = jnp.linspace(0, 1, N_steps + 1)
16+
u0, v0 = jnp.zeros((N, N)), jnp.zeros((N, N)).at[32, 32].set(1.0)
17+
fields = (u0, v0)
18+
du = lambda t, v, args: -(v**2)
19+
dv = lambda t, u, args: -jnp.fft.irfft(jnp.sin(jnp.fft.rfft(u)))
20+
sample = lambda t, y, args: y[0][64, 64] # Some arbitrary sampling function
21+
22+
23+
def speedtest(fn, name):
24+
fwd = jax.jit(fn)
25+
bwd = jax.jit(jax.grad(fn))
26+
integration_times = timeit.repeat(
27+
lambda: jax.block_until_ready(fwd(fields, ts)), number=1, repeat=10
28+
)
29+
print(f"{name} fwd: {min(integration_times)}")
30+
grad_times = timeit.repeat(
31+
lambda: jax.block_until_ready(bwd(fields, ts)), number=1, repeat=10
32+
)
33+
print(f"{name} fwd+bwd: {min(grad_times)}")
34+
35+
36+
# INTEGRATE WITH scan
37+
38+
39+
@jax.checkpoint
40+
def body(carry, t):
41+
u, v, dt = carry
42+
u = u + du(t, v, None) * dt
43+
v = v + dv(t, u, None) * dt
44+
return (u, v, dt), sample(t, (u, v), None)
45+
46+
47+
def scan_fn(fields, t):
48+
dt = t[1] - t[0]
49+
carry = (fields[0], fields[1], dt)
50+
_, values = lax.scan(body, carry, t[:-1])
51+
return jnp.mean(values**2)
52+
53+
54+
speedtest(scan_fn, "scan")
55+
56+
57+
# INTEGRATE WITH SemiImplicitEuler
58+
59+
60+
@jax.jit
61+
def dfx_fn(fields, t):
62+
return dfx.diffeqsolve(
63+
terms=(dfx.ODETerm(du), dfx.ODETerm(dv)),
64+
solver=dfx.SemiImplicitEuler(),
65+
t0=t[0],
66+
t1=t[-1],
67+
dt0=None,
68+
y0=fields,
69+
args=None,
70+
saveat=dfx.SaveAt(steps=True, fn=sample, dense=False),
71+
stepsize_controller=dfx.StepTo(ts),
72+
adjoint=dfx.RecursiveCheckpointAdjoint(checkpoints=N_steps),
73+
max_steps=N_steps,
74+
throw=False,
75+
).ys
76+
77+
78+
speedtest(dfx_fn, "SemiImplicitEuler")

diffrax/integrate.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import jax
99
import jax.numpy as jnp
1010
import jax.tree_util as jtu
11+
from jax.typing import ArrayLike
1112

1213
from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
1314
from .custom_types import Array, Bool, Int, PyTree, Scalar
1415
from .event import AbstractDiscreteTerminatingEvent
1516
from .global_interpolation import DenseInterpolation
1617
from .heuristics import is_sde, is_unsafe_sde
18+
from .misc import static_select
1719
from .saveat import SaveAt, SubSaveAt
1820
from .solution import is_okay, is_successful, RESULTS, Solution
1921
from .solver import (
@@ -29,7 +31,6 @@
2931
AbstractAdaptiveStepSizeController,
3032
AbstractStepSizeController,
3133
ConstantStepSize,
32-
PIDController,
3334
StepTo,
3435
)
3536
from .term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
@@ -141,6 +142,19 @@ def _clip_to_end(tprev, tnext, t1, keep_step):
141142
return jnp.where(clip, tclip, tnext)
142143

143144

145+
def _maybe_static(static_x: ArrayLike, x: Array) -> ArrayLike:
146+
# Some values (made_jump and result) are not used in many common use-cases. If we
147+
# detect that they're unused then we make sure they're non-Array Python values, so
148+
# that we can special case on them at trace time and get a performance boost.
149+
if isinstance(static_x, (bool, int, float, complex)):
150+
return static_x
151+
elif type(jax.core.get_aval(static_x)) is jax.core.ConcreteArray:
152+
with jax.ensure_compile_time_eval():
153+
return static_x.item()
154+
else:
155+
return x
156+
157+
144158
def loop(
145159
*,
146160
solver,
@@ -175,19 +189,27 @@ def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
175189
lambda s: s.save_state, init_state, save_state, is_leaf=_is_none
176190
)
177191

178-
# Privileged optimisation for the common case of no jumps. We can reduce
179-
# solver compile time with this.
180-
# TODO: somehow make this a non-privileged optimisation, i.e. detect when
181-
# we can make jumps or not.
182-
cannot_make_jump = isinstance(stepsize_controller, (ConstantStepSize, StepTo)) or (
183-
isinstance(stepsize_controller, PIDController)
184-
and stepsize_controller.jump_ts is None
185-
)
192+
def _handle_static(state):
193+
# We can improve runtime by resolving `result` at trace time if possible.
194+
# We can improve compiletime by resolving `made_jump` at trace time if possible.
195+
result = _maybe_static(static_result, state.result)
196+
made_jump = _maybe_static(static_made_jump, state.made_jump)
197+
return eqx.tree_at(
198+
lambda s: (s.result, s.made_jump), state, (result, made_jump)
199+
)
186200

187201
def cond_fun(state):
188-
return (state.tprev < t1) & is_successful(state.result)
202+
if isinstance(stepsize_controller, StepTo):
203+
# Privileged optimisation.
204+
# This is a measurably cheaper check than the tprev < t1 check.
205+
out = state.num_steps < len(stepsize_controller.ts) - 1
206+
else:
207+
out = state.tprev < t1
208+
state = _handle_static(state)
209+
return out & is_successful(state.result)
189210

190211
def body_fun(state):
212+
state = _handle_static(state)
191213

192214
#
193215
# Actually do some differential equation solving! Make numerical steps, adapt
@@ -201,7 +223,7 @@ def body_fun(state):
201223
state.y,
202224
args,
203225
state.solver_state,
204-
False if cannot_make_jump else state.made_jump,
226+
state.made_jump,
205227
)
206228

207229
# e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
@@ -228,16 +250,6 @@ def body_fun(state):
228250
state.controller_state,
229251
)
230252
assert jnp.result_type(keep_step) is jnp.dtype(bool)
231-
if cannot_make_jump:
232-
# Should hopefully get DCE'd out.
233-
made_jump = eqxi.error_if(
234-
made_jump,
235-
made_jump,
236-
(
237-
"Internal error in Diffrax: made unexpected jump. Please report an "
238-
"issue at https://github.com/patrick-kidger/diffrax/issues"
239-
),
240-
)
241253

242254
#
243255
# Do some book-keeping.
@@ -252,8 +264,8 @@ def body_fun(state):
252264
keep = lambda a, b: jnp.where(keep_step, a, b)
253265
y = jtu.tree_map(keep, y, state.y)
254266
solver_state = jtu.tree_map(keep, solver_state, state.solver_state)
255-
made_jump = keep(made_jump, state.made_jump)
256-
solver_result = keep(solver_result, RESULTS.successful)
267+
made_jump = static_select(keep_step, made_jump, state.made_jump)
268+
solver_result = static_select(keep_step, solver_result, RESULTS.successful)
257269

258270
# TODO: if we ever support non-terminating events, then they should go in here.
259271
# In particular the thing to be careful about is in the `if saveat.steps`
@@ -262,9 +274,8 @@ def body_fun(state):
262274
# previous step's `tnext`, i.e. immediately before the jump.)
263275

264276
# Store the first unsuccessful result we get whilst iterating (if any).
265-
result = state.result
266-
result = jnp.where(is_okay(result), solver_result, result)
267-
result = jnp.where(is_okay(result), stepsize_controller_result, result)
277+
result = static_select(is_okay(state.result), solver_result, state.result)
278+
result = static_select(is_okay(result), stepsize_controller_result, result)
268279

269280
# Count the number of steps, just for statistical purposes.
270281
num_steps = state.num_steps + 1
@@ -328,7 +339,15 @@ def _body_fun(_save_state):
328339
)
329340

330341
def maybe_inplace(i, u, x):
331-
return x.at[i].set(u, pred=keep_step)
342+
# Annoying hack. We normally call this with `x` wrapped into a buffer
343+
# (from Equinox's while loops). However we do also first trace through to
344+
# see if we can resolve some values statically, in which case normal JAX
345+
# arrays don't support the extra `pred` argument. We don't then use the
346+
# result of this so we just skip it.
347+
if _filtering:
348+
return x
349+
else:
350+
return x.at[i].set(u, pred=keep_step)
332351

333352
def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
334353
if subsaveat.steps:
@@ -389,7 +408,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
389408
terms=terms,
390409
args=args,
391410
)
392-
result = jnp.where(
411+
result = static_select(
393412
discrete_terminating_event_occurred,
394413
RESULTS.discrete_terminating_event_occurred,
395414
result,
@@ -398,6 +417,15 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
398417

399418
return new_state
400419

420+
_filtering = True
421+
static_made_jump = init_state.made_jump
422+
static_result = init_state.result
423+
filter_state = eqx.filter_eval_shape(body_fun, init_state)
424+
_filtering = False
425+
static_made_jump = filter_state.made_jump
426+
static_result = filter_state.result
427+
del filter_state
428+
401429
final_state = outer_while_loop(
402430
cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
403431
)
@@ -420,6 +448,7 @@ def _save_t1(subsaveat, save_state):
420448
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
421449
)
422450

451+
final_state = _handle_static(final_state)
423452
result = jnp.where(
424453
cond_fun(final_state), RESULTS.max_steps_reached, final_state.result
425454
)
@@ -751,7 +780,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
751780
num_accepted_steps = 0
752781
num_rejected_steps = 0
753782
made_jump = False if made_jump is None else made_jump
754-
result = jnp.array(RESULTS.successful)
783+
result = RESULTS.successful
755784
if saveat.dense:
756785
if max_steps is None:
757786
raise ValueError(
@@ -871,10 +900,11 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
871900
)
872901

873902
error_index = eqxi.unvmap_max(result)
874-
sol = eqxi.branched_error_if(
875-
sol,
876-
throw & jnp.invert(is_okay(result)),
877-
error_index,
878-
RESULTS.reverse_lookup,
879-
)
903+
if throw:
904+
sol = eqxi.branched_error_if(
905+
sol,
906+
jnp.invert(is_okay(result)),
907+
error_index,
908+
RESULTS.reverse_lookup,
909+
)
880910
return sol

diffrax/misc.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Callable, Optional, Tuple
1+
from typing import Callable, Optional, Tuple, Union
22

33
import jax
4+
import jax.core
45
import jax.flatten_util as fu
56
import jax.lax as lax
67
import jax.numpy as jnp
78
import jax.tree_util as jtu
9+
from jax.typing import ArrayLike
810

911
from .custom_types import Array, PyTree, Scalar
1012

@@ -162,3 +164,22 @@ def split_by_tree(key, tree, is_leaf: Optional[Callable[[PyTree], bool]] = None)
162164

163165
def is_tuple_of_ints(obj):
164166
return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj)
167+
168+
169+
def static_select(pred: Union[bool, Array], a: ArrayLike, b: ArrayLike) -> ArrayLike:
170+
# This is mostly useful in that it doesn't promote `a` or `b` to Arrays when the
171+
# predicate is statically known.
172+
# This in turn allows us to perform some trace-time optimisations that XLA isn't
173+
# smart enough to do on its own.
174+
if (
175+
type(pred) is not bool
176+
and type(jax.core.get_aval(pred)) is jax.core.ConcreteArray
177+
):
178+
with jax.ensure_compile_time_eval():
179+
pred = pred.item()
180+
if pred is True:
181+
return a
182+
elif pred is False:
183+
return b
184+
else:
185+
return lax.select(pred, a, b)

diffrax/solution.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Any, Dict, Optional
33

44
import equinox.internal as eqxi
5-
import jax.numpy as jnp
5+
import jax
66

77
from .custom_types import Array, Bool, PyTree, Scalar
88
from .global_interpolation import DenseInterpolation
9+
from .misc import static_select
910
from .path import AbstractPath
1011

1112

@@ -25,17 +26,20 @@ class RESULTS(metaclass=eqxi.ContainerMeta):
2526

2627

2728
def is_okay(result: RESULTS) -> Bool:
28-
return is_successful(result) | is_event(result)
29+
with jax.ensure_compile_time_eval():
30+
return is_successful(result) | is_event(result)
2931

3032

3133
def is_successful(result: RESULTS) -> Bool:
32-
return result == RESULTS.successful
34+
with jax.ensure_compile_time_eval():
35+
return result == RESULTS.successful
3336

3437

3538
# TODO: In the future we may support other event types, in which case this function
3639
# should be updated.
3740
def is_event(result: RESULTS) -> Bool:
38-
return result == RESULTS.discrete_terminating_event_occurred
41+
with jax.ensure_compile_time_eval():
42+
return result == RESULTS.discrete_terminating_event_occurred
3943

4044

4145
def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
@@ -49,8 +53,11 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
4953
event_n | event_n event_o error_o
5054
error_n | error_n error_n error_o
5155
"""
52-
out_result = jnp.where(is_okay(old_result), new_result, old_result)
53-
return jnp.where(is_okay(new_result) & is_event(old_result), old_result, out_result)
56+
with jax.ensure_compile_time_eval():
57+
out_result = static_select(is_okay(old_result), new_result, old_result)
58+
return static_select(
59+
is_okay(new_result) & is_event(old_result), old_result, out_result
60+
)
5461

5562

5663
class Solution(AbstractPath):

0 commit comments

Comments
 (0)