-
-
Notifications
You must be signed in to change notification settings - Fork 163
Open
Description
diffeqsolve traces the vf of each ODETerm before running, to check that shapes are compatible. However, when a user has a bug in their vf, traversing the resulting stack trace can be quite cumbersome (typically the real problem is about 30-50% of the way through the stack trace). The lowest error shown is just that the terms are not compatible, which is not helpful; scrolling through all of this when y0 is a complicated pytree and vf is a complicated eqx.Module is tedious. When an error occurs during this check, is it possible to exit earlier or truncate the unnecessary diffeqsolve stack trace?
MWE:
import jax.numpy as jnp
import diffrax
def vf(t, y, args):
    return x
diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
This gives the following stack trace:
Traceback (most recent call last):
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 165, in _check
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api.py", line 3014, in eval_shape
return jit(fun).eval_shape(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 352, in jit_eval_shape
p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 686, in _infer_params
return _infer_params_internal(fun, ji, args, kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 710, in _infer_params_internal
p, args_flat = _infer_params_impl(
~~~~~~~~~~~~~~~~~~^
fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 606, in _infer_params_impl
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
~~~~~~~~~~~~~~~~~~^
flat_fun, in_type, attr_token, IgnoreKey(ji.inline))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 471, in memoized_fun
ans = call(fun, *args)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 1414, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(fun, in_type)
~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
return func(*args, **kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2292, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 288, in _argnums_partial
return _fun(*args, **kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 33, in _fn
_out = _fun(*_args, **_kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_module.py", line 1060, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_term.py", line 194, in vf
out = self.vector_field(t, y, args)
File "<python-input-3>", line 2, in vf
return x
^
NameError: name 'x' is not defined
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 195, in _assert_term_compatible
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
~^^^^^
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 167, in _check
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
ValueError: Error while tracing ODETerm(vector_field=<function vf>).vf: name 'x' is not defined
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<python-input-3>", line 3, in <module>
diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 209, in __call__
return _call(self, False, args, kwargs)
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 263, in _call
marker, _, _ = out = jit_wrapper._cached(
~~~~~~~~~~~~~~~~~~~^
dynamic_donate, dynamic_nodonate, static
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 1117, in diffeqsolve
_assert_term_compatible(
~~~~~~~~~~~~~~~~~~~~~~~^
t0,
^^^
...<4 lines>...
solver.term_compatible_contr_kwargs,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 200, in _assert_term_compatible
raise ValueError(
...<3 lines>...
) from e
ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function vf>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
It would be preferable for the reported stack trace to end with "NameError: name 'x' is not defined", which is the real issue.
Metadata
Metadata
Assignees
Labels
No labels