diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 56be17ba..a7d1e905 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -352,10 +352,7 @@ def step( else: ignore_stage_g = jnp.array(self.tableau.ignore_stage_g) - # time increment - h = t1 - t0 - - # First the drift related stuff + # # First the drift related stuff a = self._embed_a_lower(self.tableau.a, dtype) c = jnp.asarray( np.insert(self.tableau.c, 0, 0.0), dtype=complex_to_real_dtype(dtype) @@ -380,6 +377,10 @@ def make_zeros_aux(leaf): # Now the diffusion related stuff # Brownian increment (and space-time Lévy area) bm_inc = diffusion.contr(t0, t1, use_levy=True) + + # time increment + h = bm_inc.dt + if not isinstance(bm_inc, self.minimal_levy_area): raise ValueError( f"The Brownian increment {bm_inc} does not have the "