Improve error messages #681

@jpbrodrick89

Description

diffeqsolve traces the vf of each ODETerm before running, to check that shapes are compatible. However, when a user has a bug in their vf, traversing the resulting stack trace is cumbersome: the actual problem typically appears about 30-50% of the way through it. The last error shown (at the bottom of the trace) only says that the terms are not compatible, which is not helpful. Scrolling back through the trace to find the root cause is tedious when y0 is a complicated pytree and vf is a complicated eqx.Module. When an error occurs, is it possible to exit earlier, or to truncate the unnecessary diffeqsolve part of the stack trace?

MWE:

import jax.numpy as jnp
import diffrax

def vf(t, y, args):
    return x  # deliberate bug: `x` is undefined

diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
This gives the following stack trace:
Traceback (most recent call last):
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 165, in _check
    vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
                              ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api.py", line 3014, in eval_shape
    return jit(fun).eval_shape(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 352, in jit_eval_shape
    p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 686, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 710, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ~~~~~~~~~~~~~~~~~~^
        fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 606, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ~~~~~~~~~~~~~~~~~~^
        flat_fun, in_type, attr_token, IgnoreKey(ji.inline))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 471, in memoized_fun
    ans = call(fun, *args)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 1414, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(fun, in_type)
                                                     ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2292, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 288, in _argnums_partial
    return _fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 33, in _fn
    _out = _fun(*_args, **_kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_module.py", line 1060, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_term.py", line 194, in vf
    out = self.vector_field(t, y, args)
  File "<python-input-3>", line 2, in vf
    return x
           ^
NameError: name 'x' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 195, in _assert_term_compatible
    jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ~^^^^^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 167, in _check
    raise ValueError(f"Error while tracing {term}.vf: " + str(e))
ValueError: Error while tracing ODETerm(vector_field=<function vf>).vf: name 'x' is not defined

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<python-input-3>", line 3, in <module>
    diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ~~~~~~~~~~~~~~~~~~~^
        dynamic_donate, dynamic_nodonate, static
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 1117, in diffeqsolve
    _assert_term_compatible(
    ~~~~~~~~~~~~~~~~~~~~~~~^
        t0,
        ^^^
    ...<4 lines>...
        solver.term_compatible_contr_kwargs,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 200, in _assert_term_compatible
    raise ValueError(
    ...<3 lines>...
    ) from e
ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function vf>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

It would be preferable for the stack trace to end with NameError: name 'x' is not defined, which is the real issue.
