88import jax
99import jax .numpy as jnp
1010import jax .tree_util as jtu
11+ from jax .typing import ArrayLike
1112
1213from .adjoint import AbstractAdjoint , DirectAdjoint , RecursiveCheckpointAdjoint
1314from .custom_types import Array , Bool , Int , PyTree , Scalar
1415from .event import AbstractDiscreteTerminatingEvent
1516from .global_interpolation import DenseInterpolation
1617from .heuristics import is_sde , is_unsafe_sde
18+ from .misc import static_select
1719from .saveat import SaveAt , SubSaveAt
1820from .solution import is_okay , is_successful , RESULTS , Solution
1921from .solver import (
2931 AbstractAdaptiveStepSizeController ,
3032 AbstractStepSizeController ,
3133 ConstantStepSize ,
32- PIDController ,
3334 StepTo ,
3435)
3536from .term import AbstractTerm , MultiTerm , ODETerm , WrapTerm
@@ -141,6 +142,19 @@ def _clip_to_end(tprev, tnext, t1, keep_step):
141142 return jnp .where (clip , tclip , tnext )
142143
143144
145+ def _maybe_static (static_x : ArrayLike , x : Array ) -> ArrayLike :
146+ # Some values (made_jump and result) are not used in many common use-cases. If we
147+ # detect that they're unused then we make sure they're non-Array Python values, so
148+ # that we can special case on them at trace time and get a performance boost.
149+ if isinstance (static_x , (bool , int , float , complex )):
150+ return static_x
151+ elif type (jax .core .get_aval (static_x )) is jax .core .ConcreteArray :
152+ with jax .ensure_compile_time_eval ():
153+ return static_x .item ()
154+ else :
155+ return x
156+
157+
144158def loop (
145159 * ,
146160 solver ,
@@ -175,19 +189,27 @@ def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
175189 lambda s : s .save_state , init_state , save_state , is_leaf = _is_none
176190 )
177191
178- # Privileged optimisation for the common case of no jumps. We can reduce
179- # solver compile time with this .
180- # TODO: somehow make this a non-privileged optimisation, i.e. detect when
181- # we can make jumps or not.
182- cannot_make_jump = isinstance ( stepsize_controller , ( ConstantStepSize , StepTo )) or (
183- isinstance ( stepsize_controller , PIDController )
184- and stepsize_controller . jump_ts is None
185- )
192+ def _handle_static ( state ):
193+ # We can improve runtime by resolving `result` at trace time if possible .
194+ # We can improve compiletime by resolving `made_jump` at trace time if possible.
195+ result = _maybe_static ( static_result , state . result )
196+ made_jump = _maybe_static ( static_made_jump , state . made_jump )
197+ return eqx . tree_at (
198+ lambda s : ( s . result , s . made_jump ), state , ( result , made_jump )
199+ )
186200
187201 def cond_fun (state ):
188- return (state .tprev < t1 ) & is_successful (state .result )
202+ if isinstance (stepsize_controller , StepTo ):
203+ # Privileged optimisation.
204+ # This is a measurably cheaper check than the tprev < t1 check.
205+ out = state .num_steps < len (stepsize_controller .ts ) - 1
206+ else :
207+ out = state .tprev < t1
208+ state = _handle_static (state )
209+ return out & is_successful (state .result )
189210
190211 def body_fun (state ):
212+ state = _handle_static (state )
191213
192214 #
193215 # Actually do some differential equation solving! Make numerical steps, adapt
@@ -201,7 +223,7 @@ def body_fun(state):
201223 state .y ,
202224 args ,
203225 state .solver_state ,
204- False if cannot_make_jump else state .made_jump ,
226+ state .made_jump ,
205227 )
206228
207229 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
@@ -228,16 +250,6 @@ def body_fun(state):
228250 state .controller_state ,
229251 )
230252 assert jnp .result_type (keep_step ) is jnp .dtype (bool )
231- if cannot_make_jump :
232- # Should hopefully get DCE'd out.
233- made_jump = eqxi .error_if (
234- made_jump ,
235- made_jump ,
236- (
237- "Internal error in Diffrax: made unexpected jump. Please report an "
238- "issue at https://github.com/patrick-kidger/diffrax/issues"
239- ),
240- )
241253
242254 #
243255 # Do some book-keeping.
@@ -252,8 +264,8 @@ def body_fun(state):
252264 keep = lambda a , b : jnp .where (keep_step , a , b )
253265 y = jtu .tree_map (keep , y , state .y )
254266 solver_state = jtu .tree_map (keep , solver_state , state .solver_state )
255- made_jump = keep ( made_jump , state .made_jump )
256- solver_result = keep ( solver_result , RESULTS .successful )
267+ made_jump = static_select ( keep_step , made_jump , state .made_jump )
268+ solver_result = static_select ( keep_step , solver_result , RESULTS .successful )
257269
258270 # TODO: if we ever support non-terminating events, then they should go in here.
259271 # In particular the thing to be careful about is in the `if saveat.steps`
@@ -262,9 +274,8 @@ def body_fun(state):
262274 # previous step's `tnext`, i.e. immediately before the jump.)
263275
264276 # Store the first unsuccessful result we get whilst iterating (if any).
265- result = state .result
266- result = jnp .where (is_okay (result ), solver_result , result )
267- result = jnp .where (is_okay (result ), stepsize_controller_result , result )
277+ result = static_select (is_okay (state .result ), solver_result , state .result )
278+ result = static_select (is_okay (result ), stepsize_controller_result , result )
268279
269280 # Count the number of steps, just for statistical purposes.
270281 num_steps = state .num_steps + 1
@@ -328,7 +339,15 @@ def _body_fun(_save_state):
328339 )
329340
330341 def maybe_inplace (i , u , x ):
331- return x .at [i ].set (u , pred = keep_step )
342+ # Annoying hack. We normally call this with `x` wrapped into a buffer
343+ # (from Equinox's while loops). However we do also first trace through to
344+ # see if we can resolve some values statically, in which case normal JAX
345+ # arrays don't support the extra `pred` argument. We don't then use the
346+ # result of this so we just skip it.
347+ if _filtering :
348+ return x
349+ else :
350+ return x .at [i ].set (u , pred = keep_step )
332351
333352 def save_steps (subsaveat : SubSaveAt , save_state : SaveState ) -> SaveState :
334353 if subsaveat .steps :
@@ -389,7 +408,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
389408 terms = terms ,
390409 args = args ,
391410 )
392- result = jnp . where (
411+ result = static_select (
393412 discrete_terminating_event_occurred ,
394413 RESULTS .discrete_terminating_event_occurred ,
395414 result ,
@@ -398,6 +417,15 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
398417
399418 return new_state
400419
420+ _filtering = True
421+ static_made_jump = init_state .made_jump
422+ static_result = init_state .result
423+ filter_state = eqx .filter_eval_shape (body_fun , init_state )
424+ _filtering = False
425+ static_made_jump = filter_state .made_jump
426+ static_result = filter_state .result
427+ del filter_state
428+
401429 final_state = outer_while_loop (
402430 cond_fun , body_fun , init_state , max_steps = max_steps , buffers = _outer_buffers
403431 )
@@ -420,6 +448,7 @@ def _save_t1(subsaveat, save_state):
420448 lambda s : s .save_state , final_state , save_state , is_leaf = _is_none
421449 )
422450
451+ final_state = _handle_static (final_state )
423452 result = jnp .where (
424453 cond_fun (final_state ), RESULTS .max_steps_reached , final_state .result
425454 )
@@ -751,7 +780,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
751780 num_accepted_steps = 0
752781 num_rejected_steps = 0
753782 made_jump = False if made_jump is None else made_jump
754- result = jnp . array ( RESULTS .successful )
783+ result = RESULTS .successful
755784 if saveat .dense :
756785 if max_steps is None :
757786 raise ValueError (
@@ -871,10 +900,11 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
871900 )
872901
873902 error_index = eqxi .unvmap_max (result )
874- sol = eqxi .branched_error_if (
875- sol ,
876- throw & jnp .invert (is_okay (result )),
877- error_index ,
878- RESULTS .reverse_lookup ,
879- )
903+ if throw :
904+ sol = eqxi .branched_error_if (
905+ sol ,
906+ jnp .invert (is_okay (result )),
907+ error_index ,
908+ RESULTS .reverse_lookup ,
909+ )
880910 return sol
0 commit comments