Description
I ran into this rather mysterious issue, which caused my code to hang indefinitely and not respond to Ctrl+C either. This issue arises in the following (admittedly rather specific) setting:
- I use `jnp.int32`s instead of Python's `int`s when passing a shape to `ShapeDtypeStruct`.
- `VirtualBrownianTree.evaluate()` is inside a `jit`'ted (or `eqx.filter_jit`) function, which is defined and called more than once.
The first time the function is called, it compiles and runs perfectly fine. The second time, however, the JIT'ted function begins compiling but never finishes; see the MWE below. Note that this only occurs when using `jnp.int32`s to pass the shape; with regular `int`s, this code works fine. The same issue occurs with `UnsafeBrownianPath`.
MWE:
```python
import diffrax as dfx
import jax
from jax import numpy as jnp
from jax import random as jr


def main():
    def generate_path(shape):
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        @jax.jit
        def eval_path(path, t0, t1):
            # Runs at trace time only, so it marks each compilation.
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)

        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-5,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.")

    generate_path(shape=(jnp.int32(2),))  # Works fine
    generate_path(shape=(jnp.int32(2),))  # Hangs indefinitely


if __name__ == "__main__":
    main()
```
Output:

```text
Compiling path evaluation...
Path evaluated successfully.
Compiling path evaluation...
```
Perhaps using `jnp.int32` to create a `ShapeDtypeStruct` is simply wrong here, but this is not obvious to me, and I suppose an error or warning would at least be warranted. Since scalars are zero-dimensional arrays in JAX, the `ShapeDtypeStruct`s you get are slightly different depending on the type of the entries passed as `shape`:
```python
import jax
from jax import numpy as jnp

shape_int32 = jax.ShapeDtypeStruct(shape=(jnp.int32(2),), dtype=jnp.float32)
# shape_int32.shape = (Array(2, dtype=int32),)
# shape_int32.shape[0] = 2
# type(shape_int32.shape[0]) = <class 'jaxlib._jax.ArrayImpl'>

shape_int = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)
# shape_int.shape = (2,)
# shape_int.shape[0] = 2
# type(shape_int.shape[0]) = <class 'int'>
```
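The error or warning suggested above could take the form of an eager check on the shape entries before they reach `ShapeDtypeStruct`. This is a minimal sketch under that assumption: `validate_shape` is a hypothetical helper, not part of JAX or diffrax, and it inspects a plain tuple, so it does not depend on how `ShapeDtypeStruct` itself stores the shape.

```python
import jax.numpy as jnp


def validate_shape(shape):
    """Hypothetical check: accept only plain Python ints as shape entries.

    bool is rejected explicitly because it is a subclass of int. NumPy and
    JAX integer scalars are rejected too, since they are not int instances.
    """
    for d in shape:
        if isinstance(d, bool) or not isinstance(d, int):
            raise TypeError(
                f"shape entries must be Python ints, got {type(d).__name__}"
            )
    return tuple(shape)


validate_shape((2,))  # accepted

try:
    validate_shape((jnp.int32(2),))
except TypeError as err:
    print(err)  # names the offending type (a JAX array class)
```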
The issue can easily be avoided by using `int` instead of `jnp.int32`, but I am very curious why this makes such a difference.
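At the call site, the workaround amounts to converting each shape entry to a Python `int` before building the `ShapeDtypeStruct`. A minimal sketch, where `as_static_shape` is a hypothetical helper; it uses `operator.index` rather than `int()` so that a float entry raises `TypeError` instead of being silently truncated, and it assumes the entries are concrete values rather than traced ones.

```python
import operator

import jax
import jax.numpy as jnp


def as_static_shape(shape):
    """Hypothetical helper: return `shape` as a tuple of plain Python ints.

    operator.index accepts Python ints, NumPy integer scalars and concrete
    (non-traced) JAX integer scalars, and raises TypeError for floats.
    """
    return tuple(operator.index(d) for d in shape)


shape = as_static_shape((jnp.int32(2),))
shape_dtype = jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)
print(shape_dtype.shape)  # (2,)
print(type(shape_dtype.shape[0]))  # <class 'int'>
```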
I tested the MWE with both JAX 0.7.0 and JAX 0.6.2, with identical results.