
JIT compiling VirtualBrownianTree.evaluate() hangs indefinitely when using jnp.int32 to pass shape to ShapeDtypeStruct #707

@alexander-de-ranitz

Description


I ran into this rather mysterious issue, which caused my code to hang indefinitely without responding to Ctrl+C either. It arises in the following (admittedly rather specific) setting:

  • I use jnp.int32 values instead of Python ints when passing a shape to ShapeDtypeStruct
  • VirtualBrownianTree.evaluate() is called inside a jitted (jax.jit or eqx.filter_jit) function, which is defined and called more than once

The first time the function is called, it compiles and runs perfectly fine. The second time, however, the jitted function starts compiling but never finishes; see the MWE below. This only happens when the shape is passed as jnp.int32 values; with regular Python ints the code works fine. The same issue occurs with UnsafeBrownianPath.

MWE:

import diffrax as dfx
import jax
from jax import numpy as jnp
from jax import random as jr

def main(): 
    def generate_path(shape): 
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        @jax.jit
        def eval_path(path, t0, t1):
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)
        
        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-5,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.")

    generate_path(shape=(jnp.int32(2),)) # Works fine
    generate_path(shape=(jnp.int32(2),)) # Hangs indefinitely

if __name__ == "__main__":
    main()

Output:

Compiling path evaluation...
Path evaluated successfully.
Compiling path evaluation...
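The second "Compiling path evaluation..." line is itself expected: a Python `print` inside a jitted function runs only while JAX traces it, and because `eval_path` is redefined on every call to `generate_path`, each call creates a fresh `jax.jit` wrapper that has to be traced and compiled from scratch. The hang is therefore in that second compilation, not in the fact that it happens. A minimal sketch of the retracing behaviour, independent of Diffrax (`make_and_call` and `trace_count` are illustrative names only):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


def make_and_call(x):
    # A new function object, and therefore a new jit cache, on every call,
    # mirroring how eval_path is redefined inside generate_path in the MWE.
    @jax.jit
    def double(y):
        global trace_count
        trace_count += 1  # Python side effect: runs at trace time, not per execution
        return 2 * y

    return double(x)


make_and_call(jnp.float32(1.0))
make_and_call(jnp.float32(1.0))
print(trace_count)  # 2: each fresh wrapper is traced once
```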

Perhaps using jnp.int32 to create a ShapeDtypeStruct is simply wrong here, but that is not obvious (at least to me), so I suppose an error or warning would be warranted. Since scalars are zero-dimensional arrays in JAX, the ShapeDtypeStructs you get with different types of shape entries are slightly different:

shape_int32 = jax.ShapeDtypeStruct(shape=(jnp.int32(2),), dtype=jnp.float32) 
# shape_int32.shape = (Array(2, dtype=int32),)  
# shape_int32.shape[0] = 2 
# type(shape_int32.shape[0]) = <class 'jaxlib._jax.ArrayImpl'>

shape_int = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32) 
# shape_int.shape = (2,) 
# shape_int.shape[0] = 2  
# type(shape_int.shape[0]) = <class 'int'>
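If JAX scalars are indeed not meant to be shape entries, an explicit check of the kind suggested above would surface the problem immediately instead of hanging inside jit. A minimal sketch, assuming that only plain Python ints should be accepted; `checked_shape_dtype_struct` is a hypothetical helper, not part of JAX or Diffrax:

```python
import jax
import jax.numpy as jnp


def checked_shape_dtype_struct(shape, dtype) -> jax.ShapeDtypeStruct:
    """Hypothetical wrapper: only accept plain Python ints as shape entries."""
    for i, dim in enumerate(shape):
        # bool is a subclass of int, so rule it out explicitly.
        if not isinstance(dim, int) or isinstance(dim, bool):
            raise TypeError(
                f"shape[{i}] has type {type(dim).__name__}; expected a Python int"
            )
    return jax.ShapeDtypeStruct(shape=tuple(shape), dtype=dtype)


ok = checked_shape_dtype_struct((2,), jnp.float32)  # accepted

try:
    checked_shape_dtype_struct((jnp.int32(2),), jnp.float32)
except TypeError as err:
    print(err)  # rejected up front, before anything is jitted
```

The check runs on the raw `shape` argument rather than on an existing ShapeDtypeStruct, so it does not depend on whether a given JAX version normalizes the shape internally.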

The issue is easily avoided by using int instead of jnp.int32, but I am very curious why this makes such a difference.
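A sketch of that workaround, assuming the shape may come from code that produces NumPy or concrete JAX integer scalars: convert every entry to a Python int before building the ShapeDtypeStruct. `as_static_shape` is a hypothetical helper name; since `int()` raises on traced values, it has to run outside any jitted function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def as_static_shape(shape) -> tuple:
    """Hypothetical helper: turn concrete integer scalars into Python ints."""
    # int() on a tracer would raise, so call this outside jit.
    return tuple(int(dim) for dim in shape)


shape = as_static_shape((jnp.int32(2), np.int64(3)))
shape_dtype = jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)
print(shape_dtype.shape)  # (2, 3)
print(type(shape_dtype.shape[0]))  # <class 'int'>
```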

I tested the MWE with both JAX 0.7.0 and 0.6.2, with identical results.
