Skip to content

Commit acb8ccd

Browse files
Fix warnings
1 parent ea1bdc9 commit acb8ccd

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

diffrax/brownian/path.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def __init__(
4343
key: "jax.random.PRNGKey",
4444
):
4545
self.shape = (
46-
jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape
46+
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
47+
if is_tuple_of_ints(shape)
48+
else shape
4749
)
4850
self.key = key
4951
if any(

diffrax/brownian/tree.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def __init__(
7575
self.t1 = t1
7676
self.tol = tol
7777
self.shape = (
78-
jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape
78+
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
79+
if is_tuple_of_ints(shape)
80+
else shape
7981
)
8082
if any(
8183
not jnp.issubdtype(x.dtype, jnp.inexact)

0 commit comments

Comments
 (0)