-
-
Notifications
You must be signed in to change notification settings - Fork 163
Open
Labels
question: User queries
Description
Hello, I am trying to use stiff ODE solvers such as `Kvaerno5`. When I run the stiff-ODE example from the Diffrax documentation:
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
class Robertson(eqx.Module):
    """Right-hand side of the Robertson chemical-kinetics ODE, used here as a stiff test problem.

    Instances are callables with the ``(t, y, args)`` signature that
    ``diffrax.ODETerm`` wraps; ``t`` and ``args`` are accepted but unused.
    """

    # Rate constants; the example script passes k1=0.04, k2=3e7, k3=1e4.
    k1: float
    k2: float
    k3: float

    def __call__(self, t, y, args):
        """Return dy/dt for the three-component state ``y`` as a stacked array.

        Analytically the three components sum to zero, so y[0] + y[1] + y[2]
        is conserved by the exact flow.
        """
        y0, y1, y2 = y[0], y[1], y[2]
        # Each of these terms appears in two of the three components.
        bilinear = self.k3 * y1 * y2
        quadratic = self.k2 * y1 ** 2
        dy0 = -self.k1 * y0 + bilinear
        dy1 = self.k1 * y0 - quadratic - bilinear
        dy2 = quadratic
        return jnp.stack([dy0, dy1, dy2])
@jax.jit
def main(k1, k2, k3):
    """Solve the Robertson problem on [0, 100] with the implicit ``Kvaerno5`` solver.

    Args:
        k1, k2, k3: rate constants forwarded to ``Robertson``.

    Returns:
        The solution returned by ``diffrax.diffeqsolve``, saved at t = 0 and at
        1e-4, 1e-3, ..., 1e2.
    """
    vector_field = Robertson(k1, k2, k3)
    term = diffrax.ODETerm(vector_field)

    # Integration window, initial step size, and initial state (y0=1, y1=y2=0).
    start_time, end_time, initial_dt = 0.0, 100.0, 0.0002
    initial_state = jnp.array([1.0, 0.0, 0.0])

    # Output times: t = 0 plus one point per decade from 1e-4 to 1e2.
    save_times = jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])

    return diffrax.diffeqsolve(
        term,
        diffrax.Kvaerno5(),
        start_time,
        end_time,
        initial_dt,
        initial_state,
        saveat=diffrax.SaveAt(ts=save_times),
        stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8),
    )
# `main` asks for rtol=atol=1e-8, which is tighter than float32 precision
# (eps ~= 1.2e-7). Enable 64-bit floats before `main` is first traced, which
# happens on the first call below.
jax.config.update("jax_enable_x64", True)

# Warm-up call: the first call JIT-compiles `main`. Block on its result so that
# compilation and this first run finish before the clock starts.
jax.block_until_ready(main(0.04, 3e7, 1e4))

# JAX dispatches work asynchronously. Block on the result before stopping the
# clock; otherwise only the dispatch, not the solve, would be timed.
# perf_counter is a monotonic clock suited to measuring intervals.
start = time.perf_counter()
sol = jax.block_until_ready(main(0.04, 3e7, 1e4))
end = time.perf_counter()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")
I get the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[91], line 45
32 sol = diffrax.diffeqsolve(
33 terms,
34 solver,
(...)
40 stepsize_controller=stepsize_controller,
41 )
42 return sol
---> 45 main(0.04, 3e7, 1e4)
47 start = time.time()
48 sol = main(0.04, 3e7, 1e4)
[... skipping hidden 13 frame]
Cell In[91], line 32, in main(k1, k2, k3)
30 saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
31 stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
---> 32 sol = diffrax.diffeqsolve(
33 terms,
34 solver,
35 t0,
36 t1,
37 dt0,
38 y0,
39 saveat=saveat,
40 stepsize_controller=stepsize_controller,
41 )
42 return sol
[... skipping hidden 18 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
1389 init_state = State(
1390 y=y0,
1391 tprev=tprev,
(...)
1409 event_mask=event_mask,
1410 )
1412 #
1413 # Main loop
1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
1417 args=args,
1418 terms=terms,
1419 solver=solver,
1420 stepsize_controller=stepsize_controller,
1421 event=event,
1422 saveat=saveat,
1423 t0=t0,
1424 t1=t1,
1425 dt0=dt0,
1426 max_steps=max_steps,
1427 init_state=init_state,
1428 throw=throw,
1429 passed_solver_state=passed_solver_state,
1430 passed_controller_state=passed_controller_state,
1431 progress_meter=progress_meter,
1432 )
1434 #
1435 # Finish up
1436 #
1438 progress_meter.close(final_state.progress_meter_state)
[... skipping hidden 1 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
295 outer_while_loop = ft.partial(
296 _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
297 )
298 msg = None
--> 299 final_state = self._loop(
300 terms=terms,
301 saveat=saveat,
302 init_state=init_state,
303 max_steps=max_steps,
304 inner_while_loop=inner_while_loop,
305 outer_while_loop=outer_while_loop,
306 **kwargs,
307 )
308 if msg is not None:
309 final_state = eqxi.nondifferentiable_backward(
310 final_state, msg=msg, symbolic=True
311 )
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:619, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
617 static_made_jump = init_state.made_jump
618 static_result = init_state.result
--> 619 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
620 if traced_jump:
621 static_made_jump = None
[... skipping hidden 16 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:349, in loop.<locals>.body_fun_aux(state)
342 state = _handle_static(state)
344 #
345 # Actually do some differential equation solving! Make numerical steps, adapt
346 # step sizes, all that jazz.
347 #
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
350 terms,
351 state.tprev,
352 state.tnext,
353 state.y,
354 args,
355 state.solver_state,
356 state.made_jump,
357 )
359 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
360 # we get a negative value for y, and then get a NaN vector field. (And then
361 # everything breaks.) See #143.
362 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
[... skipping hidden 1 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1149, in AbstractRungeKutta.step(***failed resolving arguments***)
1142 const_result = const_result_sentinel = object()
1143 # Needs to be an `eqxi.while_loop` as:
1144 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
1145 # more stage on the first step.
1146 # (b) to work around a limitation of JAX's autodiff being unable to express
1147 # "triangular computations" (every stage depends on all previous stages)
1148 # without spurious copies.
-> 1149 final_val = eqxi.while_loop(
1150 cond_stage,
1151 rk_stage,
1152 init_val,
1153 max_steps=num_stages,
1154 buffers=buffers,
1155 kind="checkpointed" if self.scan_kind is None else self.scan_kind,
1156 checkpoints=num_stages,
1157 base=num_stages,
1158 )
1159 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
1160 assert const_result is not const_result_sentinel
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
105 elif kind == "checkpointed":
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
116 del kind, checkpoints
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py:247, in checkpointed_while_loop(***failed resolving arguments***)
245 cond_fun_ = filter_closure_convert(cond_fun_, init_val_)
246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
[... skipping hidden 17 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
472 step, pred, _, val = val
473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
475 # Needed to work with `disable_jit`, as then we lose the automatic
476 # ArrayLike->Array cast provided by JAX's while loops.
477 # The input `val` is already cast to Array below, so this matches that.
478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:984, in AbstractRungeKutta.step.<locals>.rk_stage(val)
982 if eval_fs:
983 jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
--> 984 nonlinear_sol = optx.root_find(
985 _implicit_relation_f,
986 self.root_finder, # pyright: ignore
987 f_pred,
988 f_implicit_args,
989 options=dict(init_state=jac_f),
990 throw=False,
991 max_steps=self.root_find_max_steps, # pyright: ignore
992 )
993 implicit_fi = nonlinear_sol.value
994 implicit_ki = _unused
[... skipping hidden 18 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_root_find.py:220, in root_find(fn, solver, y0, args, options, has_aux, max_steps, adjoint, throw, tags)
218 if options is None:
219 options = {}
--> 220 return iterative_solve(
221 fn,
222 solver,
223 y0,
224 args,
225 options,
226 max_steps=max_steps,
227 adjoint=adjoint,
228 throw=throw,
229 tags=tags,
230 f_struct=f_struct,
231 aux_struct=aux_struct,
232 rewrite_fn=_rewrite_fn,
233 )
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:346, in iterative_solve(fn, solver, y0, args, options, max_steps, adjoint, throw, tags, f_struct, aux_struct, rewrite_fn)
334 aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
335 inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
336 (
337 out,
338 (
339 num_steps,
340 result,
341 dynamic_final_state,
342 static_state,
343 aux,
344 stats,
345 ),
--> 346 ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
347 final_state = eqx.combine(dynamic_final_state, unwrap_jaxpr(static_state.value))
348 stats = {"num_steps": num_steps, "max_steps": max_steps, **stats}
[... skipping hidden 1 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_adjoint.py:134, in ImplicitAdjoint.apply(self, primal_fn, rewrite_fn, inputs, tags)
132 def apply(self, primal_fn, rewrite_fn, inputs, tags):
133 inputs = inputs + (ft.partial(eqxi.while_loop, kind="lax"),)
--> 134 return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:60, in implicit_jvp(fn_primal, fn_rewrite, inputs, tags, linear_solver)
58 assert _is_global_function(fn_primal)
59 assert _is_global_function(fn_rewrite)
---> 60 root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
61 return root, jtu.tree_map(eqxi.nondifferentiable_backward, residual)
[... skipping hidden 14 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:67, in _implicit_impl(***failed resolving arguments***)
64 @eqx.filter_custom_jvp
65 def _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver):
66 del fn_rewrite, tags, linear_solver
---> 67 return jtu.tree_map(jnp.asarray, fn_primal(inputs))
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:242, in _iterate(***failed resolving arguments***)
239 _, _, state, _ = carry
240 return solver.buffers(state)
--> 242 final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
243 final_y, num_steps, dynamic_final_state, final_aux = final_carry
244 final_state = eqx.combine(static_state, dynamic_final_state)
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:103, in while_loop(***failed resolving arguments***)
99 cond_fun_, body_fun_, init_val_, _ = common_rewrite(
100 cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
101 )
102 del cond_fun, body_fun, init_val
--> 103 _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
104 return final_val
105 elif kind == "checkpointed":
[... skipping hidden 10 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
472 step, pred, _, val = val
473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
475 # Needed to work with `disable_jit`, as then we lose the automatic
476 # ArrayLike->Array cast provided by JAX's while loops.
477 # The input `val` is already cast to Array below, so this matches that.
478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:232, in _iterate.<locals>.body_fun(carry)
230 y, num_steps, dynamic_state, _ = carry
231 state = eqx.combine(static_state, dynamic_state)
--> 232 new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
233 new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)
235 assert eqx.tree_equal(static_state, new_static_state) is True
[... skipping hidden 1 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_root_finder/_verychord.py:127, in VeryChord.step(***failed resolving arguments***)
125 jac, linear_state = state.linear_state
126 linear_state = lax.stop_gradient(linear_state)
--> 127 sol = lx.linear_solve(
128 jac, fx, self.linear_solver, state=linear_state, throw=False
129 )
130 diff = sol.value
131 new_y = (y**ω - diff**ω).ω
[... skipping hidden 18 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/lineax/_solve.py:810, in linear_solve(operator, vector, solver, options, state, throw)
804 options = eqxi.nondifferentiable(
805 options, name="`lineax.linear_solve(..., options=...)`"
806 )
807 solver = eqxi.nondifferentiable(
808 solver, name="`lineax.linear_solve(..., solver=...)`"
809 )
--> 810 solution, result, stats = eqxi.filter_primitive_bind(
811 linear_solve_p, operator, state, vector, options, solver, throw
812 )
813 # TODO: prevent forward-mode autodiff through stats
814 stats = eqxi.nondifferentiable_backward(stats)
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_primitive.py:273, in filter_primitive_bind(prim, *args)
271 static = tuple(_missing_dynamic if is_array(x) else x for x in flat)
272 flatten = Flatten()
--> 273 flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
274 treedef_out, static_out = flatten.get()
275 return combine(jtu.tree_unflatten(treedef_out, flat_out), static_out)
[... skipping hidden 5 frame]
File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/jax/_src/util.py:465, in multi_weakref_lru_cache.<locals>.wrapper(*orig_args, **orig_kwargs)
462 return cached_call(acc_weakrefs[0],
463 *args, **kwargs)
464 else:
--> 465 value_to_weakref = {v: weakref.ref(v, remove_weakref)
466 for v in set(acc_weakrefs)}
467 key = MultiWeakRefCacheKey(weakrefs=tuple(value_to_weakref[v]
468 for v in acc_weakrefs))
469 return cached_call(key, *args, **kwargs)
TypeError: cannot create weak reference to 'Flatten' object
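I am using JAX 0.7.1 and Diffrax 0.7.0. The error is raised inside Equinox's `filter_primitive_bind`: JAX's `multi_weakref_lru_cache` tries to create a weak reference to Equinox's `Flatten` object and fails. The installed Equinox version is therefore relevant as well.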
Metadata
Metadata
Assignees
Labels
question: User queries