Skip to content

Commit 0920550

Browse files
Doc fix.
1 parent 316fca4 commit 0920550

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

diffrax/brownian/tree.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,9 @@ def _body_fun(_state):
202202
- `t1`: The start of the interval the Brownian motion is defined over.
203203
- `tol`: The discretisation that `[t0, t1]` is discretised to.
204204
- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape,
205-
dtype, and PyTree structure of the output. For simplicity, `shape` can also just
206-
be a tuple of integers, describing the shape of a single JAX array. In that case
207-
the dtype is chosen to be `float64` if `JAX_ENABLE_X64=True` and `float32`
208-
otherwise.
205+
dtype, and PyTree structure of the output. For simplicity, `shape` can also just
206+
be a tuple of integers, describing the shape of a single JAX array. In that case
207+
the dtype is chosen to be the default floating-point dtype.
209208
- `key`: A random key.
210209
211210
!!! info

0 commit comments

Comments
 (0)