From 4fc346a39c890a9fd49a0c976234a728f8ef1bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20St=C3=B6lzle?= Date: Wed, 13 Sep 2023 13:13:35 +0000 Subject: [PATCH] Fix assert error for type of `keep_step` --- diffrax/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 814c90d9..1e814f9f 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -249,7 +249,7 @@ def body_fun(state): error_order, state.controller_state, ) - assert jnp.result_type(keep_step) is jnp.dtype(bool) + assert jnp.result_type(keep_step) in [bool, jnp.dtype(bool)] # # Do some book-keeping.