Skip to content

Commit dd021bb

Browse files
committed
fix schedule test
1 parent 5c5abd3 commit dd021bb

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

tests/test_utils/test_integrate.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,20 @@
1313
TOL_DET = 1e-3
1414

1515

16-
def test_scheduled_integration():
17-
import keras
18-
from bayesflow.utils import integrate
19-
16+
@pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"])
17+
def test_scheduled_integration(method):
2018
def fn(t, x):
2119
return {"x": t**2}
2220

23-
steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0])
24-
approximate_result = 0.0 + 0.5**2 * 0.5
25-
result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
26-
assert result == approximate_result
21+
def analytical_result(t):
22+
return (t**3) / 3.0
2723

24+
steps = keras.ops.arange(0.0, 1.0 + 1e-6, 0.01)
25+
result = integrate(fn, {"x": 0.0}, steps=steps, method=method)["x"]
26+
np.testing.assert_allclose(result, analytical_result(steps[-1]), atol=1e-1, rtol=1e-1)
2827

29-
def test_scipy_integration():
30-
import keras
31-
from bayesflow.utils import integrate
3228

29+
def test_scipy_integration():
3330
def fn(t, x):
3431
return {"x": keras.ops.exp(t)}
3532

0 commit comments

Comments
 (0)