Skip to content

Commit 5a1a3fa

Browse files
committed
filter kwargs
1 parent ad27606 commit 5a1a3fa

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

bayesflow/utils/integrate.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,21 @@ def rk45_step(
7373

7474
if k1 is None: # reuse k1 if available
7575
k1 = fn(time, **filter_kwargs(state, fn))
76-
k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h))
77-
k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h))
78-
k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h))
76+
k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn))
77+
k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn))
78+
k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn))
7979
k5 = fn(
8080
time + h * (8 / 9),
81-
**add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h),
81+
**filter_kwargs(
82+
add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn
83+
),
8284
)
8385
k6 = fn(
8486
time + h,
85-
**add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h),
87+
**filter_kwargs(
88+
add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h),
89+
fn,
90+
),
8691
)
8792

8893
# 5th order solution
@@ -140,24 +145,38 @@ def tsit5_step(
140145

141146
if k1 is None: # reuse k1 if available
142147
k1 = fn(time, **filter_kwargs(state, fn))
143-
k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h))
144-
k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h))
148+
k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn))
149+
k3 = fn(
150+
time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn)
151+
)
145152
k4 = fn(
146-
time + h * c4, **add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h)
153+
time + h * c4,
154+
**filter_kwargs(
155+
add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn
156+
),
147157
)
148158
k5 = fn(
149159
time + h * c5,
150-
**add_scaled(
151-
state, [k1, k2, k3, k4], [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h
160+
**filter_kwargs(
161+
add_scaled(
162+
state,
163+
[k1, k2, k3, k4],
164+
[5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
165+
h,
166+
),
167+
fn,
152168
),
153169
)
154170
k6 = fn(
155171
time + h,
156-
**add_scaled(
157-
state,
158-
[k1, k2, k3, k4, k5],
159-
[5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838],
160-
h,
172+
**filter_kwargs(
173+
add_scaled(
174+
state,
175+
[k1, k2, k3, k4, k5],
176+
[5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838],
177+
h,
178+
),
179+
fn,
161180
),
162181
)
163182

0 commit comments

Comments
 (0)