88TOLERANCE_EULER = 1e-3 # Euler with fixed steps requires a larger tolerance
99
1010# tolerances for SDE tests
11- TOL_MEAN = 3e -2
11+ TOL_MEAN = 5e -2
1212TOL_VAR = 5e-2
13- TOL_DET = 1e-3
1413
1514
1615@pytest .mark .parametrize ("method" , ["euler" , "rk45" , "tsit5" ])
@@ -123,7 +122,6 @@ def test_forward_additive_ou_weak_means_and_vars(method, use_adapt):
123122 x_0 = 1.2 # initial condition at time 0
124123 T = 1.0
125124
126- # batch of trajectories
127125 N = 10000
128126 seed = keras .random .SeedGenerator (42 )
129127
@@ -149,15 +147,14 @@ def diffusion_fn(t, x):
149147 steps = steps ,
150148 seed = seed ,
151149 method = method ,
152- max_steps = 1_000 ,
153150 )
154151
155152 x_T = np .array (out ["x" ])
156153 emp_mean = float (x_T .mean ())
157154 emp_var = float (x_T .var ())
158155
159- np .testing .assert_allclose (emp_mean , exp_mean , atol = TOL_MEAN , rtol = 0.0 )
160- np .testing .assert_allclose (emp_var , exp_var , atol = TOL_VAR , rtol = 0.0 )
156+ np .testing .assert_allclose (emp_mean , exp_mean , atol = TOL_MEAN )
157+ np .testing .assert_allclose (emp_var , exp_var , atol = TOL_VAR )
161158
162159
163160@pytest .mark .parametrize (
@@ -188,8 +185,7 @@ def test_backward_additive_ou_weak_means_and_vars(method, use_adapt):
188185 x_T = 1.2 # initial condition at time T
189186 T = 1.0
190187
191- # batch of trajectories
192- N = 10000 # large enough to control sampling error
188+ N = 10000
193189 seed = keras .random .SeedGenerator (42 )
194190
195191 def drift_fn (t , x ):
@@ -216,15 +212,14 @@ def diffusion_fn(t, x):
216212 steps = steps ,
217213 seed = seed ,
218214 method = method ,
219- max_steps = 1_000 ,
220215 )
221216
222217 x_0 = np .array (out ["x" ])
223218 emp_mean = float (x_0 .mean ())
224219 emp_var = float (x_0 .var ())
225220
226- np .testing .assert_allclose (emp_mean , exp_mean , atol = TOL_MEAN , rtol = 0.0 )
227- np .testing .assert_allclose (emp_var , exp_var , atol = TOL_VAR , rtol = 0.0 )
221+ np .testing .assert_allclose (emp_mean , exp_mean , atol = TOL_MEAN )
222+ np .testing .assert_allclose (emp_var , exp_var , atol = TOL_VAR )
228223
229224
230225@pytest .mark .parametrize (
@@ -270,7 +265,7 @@ def diffusion_fn(t, x):
270265 )["x" ]
271266
272267 exact = x0 * np .exp (a * T )
273- np .testing .assert_allclose (np .array (out ).mean (), exact , atol = TOL_DET , rtol = 0.1 )
268+ np .testing .assert_allclose (np .array (out ).mean (), exact , atol = 1e-3 , rtol = 0.1 )
274269
275270
276271@pytest .mark .parametrize ("steps" , [500 ])
0 commit comments