Skip to content

Cannot run the stiff ODE example from the docs #700

@Rumoa

Description

@Rumoa

Hello, I am trying to use stiff ODE solvers such as Kvaerno5. When running the stiff ODE example from the documentation:

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp


class Robertson(eqx.Module):
    """Vector field of Robertson's three-component stiff ODE system.

    An instance is called as ``f(t, y, args)`` (the calling convention expected by
    ``diffrax.ODETerm``) and returns ``dy/dt`` as a length-3 array. ``t`` and
    ``args`` are accepted only for that convention and are not used.
    """

    k1: float  # coefficient of the k1 * y[0] term (moves y[0] into y[1])
    k2: float  # coefficient of the k2 * y[1]**2 term (moves y[1] into y[2])
    k3: float  # coefficient of the k3 * y[1] * y[2] term (moves y[1] back into y[0])

    def __call__(self, t, y, args):
        y_a, y_b, y_c = y[0], y[1], y[2]

        # Each transfer term shows up in two of the three equations; compute once.
        # Operand grouping matches the textbook form so results are bit-identical.
        rate_1 = self.k1 * y_a
        rate_2 = self.k2 * y_b**2
        rate_3 = self.k3 * y_b * y_c

        dy_a = -self.k1 * y_a + rate_3
        dy_b = rate_1 - rate_2 - rate_3
        dy_c = rate_2
        return jnp.stack([dy_a, dy_b, dy_c])


@jax.jit
def main(k1, k2, k3):
    """Integrate the Robertson system from t=0 to t=100 with the stiff Kvaerno5 solver.

    Args:
        k1, k2, k3: coefficients forwarded to ``Robertson``.

    Returns:
        The result of ``diffrax.diffeqsolve``, saved at the eight times listed in
        ``save_times`` and using an adaptive PID step-size controller with
        rtol = atol = 1e-8.
    """
    # Save points spread logarithmically from 1e-4 up to the final time 1e2.
    save_times = jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])
    initial_state = jnp.array([1.0, 0.0, 0.0])
    vector_field = diffrax.ODETerm(Robertson(k1, k2, k3))

    return diffrax.diffeqsolve(
        vector_field,
        diffrax.Kvaerno5(),
        0.0,  # t0
        100.0,  # t1
        0.0002,  # dt0
        initial_state,
        saveat=diffrax.SaveAt(ts=save_times),
        stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8),
    )


# NOTE(review): the `TypeError: cannot create weak reference to 'Flatten' object`
# reported for this script is raised when equinox's `filter_primitive_bind` passes
# a `Flatten()` object as a primitive parameter into JAX's weakref-based cache.
# That points to an equinox release that does not support the installed JAX
# (0.7.1); upgrading equinox (or pinning an older JAX) is the likely fix — confirm
# against the equinox release notes.

# `main` uses rtol = atol = 1e-8, which is below float32 machine epsilon
# (~1.19e-7), so the solve needs float64. This must be enabled before `main` is
# first called, because that first call is what traces and compiles it.
jax.config.update("jax_enable_x64", True)

# Warm-up call: pays the tracing/compilation cost outside the timed region.
# Block on it so that its execution cannot overlap with the timed run below.
jax.block_until_ready(main(0.04, 3e7, 1e4))

start = time.perf_counter()
# JAX dispatches work asynchronously; without blocking, the measured interval
# would cover only dispatch, not the actual solve.
sol = jax.block_until_ready(main(0.04, 3e7, 1e4))
end = time.perf_counter()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[91], line 45
     32     sol = diffrax.diffeqsolve(
     33         terms,
     34         solver,
   (...)
     40         stepsize_controller=stepsize_controller,
     41     )
     42     return sol
---> 45 main(0.04, 3e7, 1e4)
     47 start = time.time()
     48 sol = main(0.04, 3e7, 1e4)

    [... skipping hidden 13 frame]

Cell In[91], line 32, in main(k1, k2, k3)
     30 saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
     31 stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
---> 32 sol = diffrax.diffeqsolve(
     33     terms,
     34     solver,
     35     t0,
     36     t1,
     37     dt0,
     38     y0,
     39     saveat=saveat,
     40     stepsize_controller=stepsize_controller,
     41 )
     42 return sol

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1389 init_state = State(
   1390     y=y0,
   1391     tprev=tprev,
   (...)
   1409     event_mask=event_mask,
   1410 )
   1412 #
   1413 # Main loop
   1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
   1417     args=args,
   1418     terms=terms,
   1419     solver=solver,
   1420     stepsize_controller=stepsize_controller,
   1421     event=event,
   1422     saveat=saveat,
   1423     t0=t0,
   1424     t1=t1,
   1425     dt0=dt0,
   1426     max_steps=max_steps,
   1427     init_state=init_state,
   1428     throw=throw,
   1429     passed_solver_state=passed_solver_state,
   1430     passed_controller_state=passed_controller_state,
   1431     progress_meter=progress_meter,
   1432 )
   1434 #
   1435 # Finish up
   1436 #
   1438 progress_meter.close(final_state.progress_meter_state)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
    295     outer_while_loop = ft.partial(
    296         _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
    297     )
    298     msg = None
--> 299 final_state = self._loop(
    300     terms=terms,
    301     saveat=saveat,
    302     init_state=init_state,
    303     max_steps=max_steps,
    304     inner_while_loop=inner_while_loop,
    305     outer_while_loop=outer_while_loop,
    306     **kwargs,
    307 )
    308 if msg is not None:
    309     final_state = eqxi.nondifferentiable_backward(
    310         final_state, msg=msg, symbolic=True
    311     )

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:619, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
    617 static_made_jump = init_state.made_jump
    618 static_result = init_state.result
--> 619 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
    620 if traced_jump:
    621     static_made_jump = None

    [... skipping hidden 16 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:349, in loop.<locals>.body_fun_aux(state)
    342 state = _handle_static(state)
    344 #
    345 # Actually do some differential equation solving! Make numerical steps, adapt
    346 # step sizes, all that jazz.
    347 #
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    350     terms,
    351     state.tprev,
    352     state.tnext,
    353     state.y,
    354     args,
    355     state.solver_state,
    356     state.made_jump,
    357 )
    359 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
    360 # we get a negative value for y, and then get a NaN vector field. (And then
    361 # everything breaks.) See #143.
    362 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1149, in AbstractRungeKutta.step(***failed resolving arguments***)
   1142 const_result = const_result_sentinel = object()
   1143 # Needs to be an `eqxi.while_loop` as:
   1144 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
   1145 #     more stage on the first step.
   1146 # (b) to work around a limitation of JAX's autodiff being unable to express
   1147 #     "triangular computations" (every stage depends on all previous stages)
   1148 #     without spurious copies.
-> 1149 final_val = eqxi.while_loop(
   1150     cond_stage,
   1151     rk_stage,
   1152     init_val,
   1153     max_steps=num_stages,
   1154     buffers=buffers,
   1155     kind="checkpointed" if self.scan_kind is None else self.scan_kind,
   1156     checkpoints=num_stages,
   1157     base=num_stages,
   1158 )
   1159 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
   1160 assert const_result is not const_result_sentinel

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
    105 elif kind == "checkpointed":
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":
    116     del kind, checkpoints

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py:247, in checkpointed_while_loop(***failed resolving arguments***)
    245 cond_fun_ = filter_closure_convert(cond_fun_, init_val_)
    246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)
    249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )

    [... skipping hidden 17 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
    472 step, pred, _, val = val
    473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
    475 # Needed to work with `disable_jit`, as then we lose the automatic
    476 # ArrayLike->Array cast provided by JAX's while loops.
    477 # The input `val` is already cast to Array below, so this matches that.
    478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:984, in AbstractRungeKutta.step.<locals>.rk_stage(val)
    982 if eval_fs:
    983     jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
--> 984     nonlinear_sol = optx.root_find(
    985         _implicit_relation_f,
    986         self.root_finder,  # pyright: ignore
    987         f_pred,
    988         f_implicit_args,
    989         options=dict(init_state=jac_f),
    990         throw=False,
    991         max_steps=self.root_find_max_steps,  # pyright: ignore
    992     )
    993     implicit_fi = nonlinear_sol.value
    994     implicit_ki = _unused

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_root_find.py:220, in root_find(fn, solver, y0, args, options, has_aux, max_steps, adjoint, throw, tags)
    218 if options is None:
    219     options = {}
--> 220 return iterative_solve(
    221     fn,
    222     solver,
    223     y0,
    224     args,
    225     options,
    226     max_steps=max_steps,
    227     adjoint=adjoint,
    228     throw=throw,
    229     tags=tags,
    230     f_struct=f_struct,
    231     aux_struct=aux_struct,
    232     rewrite_fn=_rewrite_fn,
    233 )

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:346, in iterative_solve(fn, solver, y0, args, options, max_steps, adjoint, throw, tags, f_struct, aux_struct, rewrite_fn)
    334 aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
    335 inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
    336 (
    337     out,
    338     (
    339         num_steps,
    340         result,
    341         dynamic_final_state,
    342         static_state,
    343         aux,
    344         stats,
    345     ),
--> 346 ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
    347 final_state = eqx.combine(dynamic_final_state, unwrap_jaxpr(static_state.value))
    348 stats = {"num_steps": num_steps, "max_steps": max_steps, **stats}

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_adjoint.py:134, in ImplicitAdjoint.apply(self, primal_fn, rewrite_fn, inputs, tags)
    132 def apply(self, primal_fn, rewrite_fn, inputs, tags):
    133     inputs = inputs + (ft.partial(eqxi.while_loop, kind="lax"),)
--> 134     return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:60, in implicit_jvp(fn_primal, fn_rewrite, inputs, tags, linear_solver)
     58 assert _is_global_function(fn_primal)
     59 assert _is_global_function(fn_rewrite)
---> 60 root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
     61 return root, jtu.tree_map(eqxi.nondifferentiable_backward, residual)

    [... skipping hidden 14 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:67, in _implicit_impl(***failed resolving arguments***)
     64 @eqx.filter_custom_jvp
     65 def _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver):
     66     del fn_rewrite, tags, linear_solver
---> 67     return jtu.tree_map(jnp.asarray, fn_primal(inputs))

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:242, in _iterate(***failed resolving arguments***)
    239     _, _, state, _ = carry
    240     return solver.buffers(state)
--> 242 final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
    243 final_y, num_steps, dynamic_final_state, final_aux = final_carry
    244 final_state = eqx.combine(static_state, dynamic_final_state)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:103, in while_loop(***failed resolving arguments***)
     99     cond_fun_, body_fun_, init_val_, _ = common_rewrite(
    100         cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
    101     )
    102     del cond_fun, body_fun, init_val
--> 103     _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
    104     return final_val
    105 elif kind == "checkpointed":

    [... skipping hidden 10 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
    472 step, pred, _, val = val
    473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
    475 # Needed to work with `disable_jit`, as then we lose the automatic
    476 # ArrayLike->Array cast provided by JAX's while loops.
    477 # The input `val` is already cast to Array below, so this matches that.
    478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:232, in _iterate.<locals>.body_fun(carry)
    230 y, num_steps, dynamic_state, _ = carry
    231 state = eqx.combine(static_state, dynamic_state)
--> 232 new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
    233 new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)
    235 assert eqx.tree_equal(static_state, new_static_state) is True

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_root_finder/_verychord.py:127, in VeryChord.step(***failed resolving arguments***)
    125 jac, linear_state = state.linear_state
    126 linear_state = lax.stop_gradient(linear_state)
--> 127 sol = lx.linear_solve(
    128     jac, fx, self.linear_solver, state=linear_state, throw=False
    129 )
    130 diff = sol.value
    131 new_y = (y**ω - diff**ω).ω

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/lineax/_solve.py:810, in linear_solve(operator, vector, solver, options, state, throw)
    804 options = eqxi.nondifferentiable(
    805     options, name="`lineax.linear_solve(..., options=...)`"
    806 )
    807 solver = eqxi.nondifferentiable(
    808     solver, name="`lineax.linear_solve(..., solver=...)`"
    809 )
--> 810 solution, result, stats = eqxi.filter_primitive_bind(
    811     linear_solve_p, operator, state, vector, options, solver, throw
    812 )
    813 # TODO: prevent forward-mode autodiff through stats
    814 stats = eqxi.nondifferentiable_backward(stats)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_primitive.py:273, in filter_primitive_bind(prim, *args)
    271 static = tuple(_missing_dynamic if is_array(x) else x for x in flat)
    272 flatten = Flatten()
--> 273 flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
    274 treedef_out, static_out = flatten.get()
    275 return combine(jtu.tree_unflatten(treedef_out, flat_out), static_out)

    [... skipping hidden 5 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/jax/_src/util.py:465, in multi_weakref_lru_cache.<locals>.wrapper(*orig_args, **orig_kwargs)
    462   return cached_call(acc_weakrefs[0],
    463                      *args, **kwargs)
    464 else:
--> 465   value_to_weakref = {v: weakref.ref(v, remove_weakref)
    466                       for v in set(acc_weakrefs)}
    467   key = MultiWeakRefCacheKey(weakrefs=tuple(value_to_weakref[v]
    468                                             for v in acc_weakrefs))
    469   return cached_call(key, *args, **kwargs)

TypeError: cannot create weak reference to 'Flatten' object

I am using JAX 0.7.1 and Diffrax 0.7.0.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions