@@ -73,16 +73,21 @@ def rk45_step(
7373
7474 if k1 is None : # reuse k1 if available
7575 k1 = fn (time , ** filter_kwargs (state , fn ))
76- k2 = fn (time + h * (1 / 5 ), ** add_scaled (state , [k1 ], [1 / 5 ], h ))
77- k3 = fn (time + h * (3 / 10 ), ** add_scaled (state , [k1 , k2 ], [3 / 40 , 9 / 40 ], h ))
78- k4 = fn (time + h * (4 / 5 ), ** add_scaled (state , [k1 , k2 , k3 ], [44 / 45 , - 56 / 15 , 32 / 9 ], h ))
76+ k2 = fn (time + h * (1 / 5 ), ** filter_kwargs ( add_scaled (state , [k1 ], [1 / 5 ], h ), fn ))
77+ k3 = fn (time + h * (3 / 10 ), ** filter_kwargs ( add_scaled (state , [k1 , k2 ], [3 / 40 , 9 / 40 ], h ), fn ))
78+ k4 = fn (time + h * (4 / 5 ), ** filter_kwargs ( add_scaled (state , [k1 , k2 , k3 ], [44 / 45 , - 56 / 15 , 32 / 9 ], h ), fn ))
7979 k5 = fn (
8080 time + h * (8 / 9 ),
81- ** add_scaled (state , [k1 , k2 , k3 , k4 ], [19372 / 6561 , - 25360 / 2187 , 64448 / 6561 , - 212 / 729 ], h ),
81+ ** filter_kwargs (
82+ add_scaled (state , [k1 , k2 , k3 , k4 ], [19372 / 6561 , - 25360 / 2187 , 64448 / 6561 , - 212 / 729 ], h ), fn
83+ ),
8284 )
8385 k6 = fn (
8486 time + h ,
85- ** add_scaled (state , [k1 , k2 , k3 , k4 , k5 ], [9017 / 3168 , - 355 / 33 , 46732 / 5247 , 49 / 176 , - 5103 / 18656 ], h ),
87+ ** filter_kwargs (
88+ add_scaled (state , [k1 , k2 , k3 , k4 , k5 ], [9017 / 3168 , - 355 / 33 , 46732 / 5247 , 49 / 176 , - 5103 / 18656 ], h ),
89+ fn ,
90+ ),
8691 )
8792
8893 # 5th order solution
@@ -140,24 +145,38 @@ def tsit5_step(
140145
141146 if k1 is None : # reuse k1 if available
142147 k1 = fn (time , ** filter_kwargs (state , fn ))
143- k2 = fn (time + h * c2 , ** add_scaled (state , [k1 ], [0.161 ], h ))
144- k3 = fn (time + h * c3 , ** add_scaled (state , [k1 , k2 ], [- 0.0084806554923570 , 0.3354806554923570 ], h ))
148+ k2 = fn (time + h * c2 , ** filter_kwargs (add_scaled (state , [k1 ], [0.161 ], h ), fn ))
149+ k3 = fn (
150+ time + h * c3 , ** filter_kwargs (add_scaled (state , [k1 , k2 ], [- 0.0084806554923570 , 0.3354806554923570 ], h ), fn )
151+ )
145152 k4 = fn (
146- time + h * c4 , ** add_scaled (state , [k1 , k2 , k3 ], [2.897153057105494 , - 6.359448489975075 , 4.362295432869581 ], h )
153+ time + h * c4 ,
154+ ** filter_kwargs (
155+ add_scaled (state , [k1 , k2 , k3 ], [2.897153057105494 , - 6.359448489975075 , 4.362295432869581 ], h ), fn
156+ ),
147157 )
148158 k5 = fn (
149159 time + h * c5 ,
150- ** add_scaled (
151- state , [k1 , k2 , k3 , k4 ], [5.325864828439257 , - 11.74888356406283 , 7.495539342889836 , - 0.09249506636175525 ], h
160+ ** filter_kwargs (
161+ add_scaled (
162+ state ,
163+ [k1 , k2 , k3 , k4 ],
164+ [5.325864828439257 , - 11.74888356406283 , 7.495539342889836 , - 0.09249506636175525 ],
165+ h ,
166+ ),
167+ fn ,
152168 ),
153169 )
154170 k6 = fn (
155171 time + h ,
156- ** add_scaled (
157- state ,
158- [k1 , k2 , k3 , k4 , k5 ],
159- [5.86145544294270 , - 12.92096931784711 , 8.159367898576159 , - 0.07158497328140100 , - 0.02826905039406838 ],
160- h ,
172+ ** filter_kwargs (
173+ add_scaled (
174+ state ,
175+ [k1 , k2 , k3 , k4 , k5 ],
176+ [5.86145544294270 , - 12.92096931784711 , 8.159367898576159 , - 0.07158497328140100 , - 0.02826905039406838 ],
177+ h ,
178+ ),
179+ fn ,
161180 ),
162181 )
163182
0 commit comments