@@ -196,18 +196,20 @@ function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
196196end
197197
198198function (n:: NeuralDSDE{M} )(x,p,st) where {M<: Lux.AbstractExplicitLayer }
199- function dudt_ (u,p,t;st= st)
200- u_, st. state1 = n. model1 (u,p. p1,st. state1)
199+ st1 = st. state1
200+ st2 = st. state2
201+ function dudt_ (u,p,t;st= st1)
202+ u_, st = n. model1 (u,p. p1,st)
201203 return u_
202204 end
203- function g (u,p,t;st= st )
204- u_, st. state2 = n. model2 (u,p. p2,st. state2 )
205+ function g (u,p,t;st= st2 )
206+ u_, st = n. model2 (u,p. p2,st)
205207 return u_
206208 end
207209
208210 ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
209211 prob = SDEProblem {false} (ff,g,x,n. tspan,p)
210- return solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... ), st
212+ return solve (prob,n. args... ;sensealg= InterpolatingAdjoint (),n. kwargs... ), (state1 = st1, state2 = st2)
211213end
212214
213215"""
@@ -306,18 +308,20 @@ function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
306308end
307309
308310function (n:: NeuralSDE{P,M} )(x,p,st) where {P,M<: Lux.AbstractExplicitLayer }
309- function dudt_ (u,p,t;st= st)
310- u_, st. state1 = n. model1 (u,p. p1,st. state1)
311+ st1 = st. state1
312+ st2 = st. state2
313+ function dudt_ (u,p,t;st= st1)
314+ u_, st = n. model1 (u,p. p1,st)
311315 return u_
312316 end
313- function g (u,p,t;st= st )
314- u_, st. state2 = n. model2 (u,p. p2,st. state2 )
317+ function g (u,p,t;st= st2 )
318+ u_, st = n. model2 (u,p. p2,st)
315319 return u_
316320 end
317321
318322 ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
319323 prob = SDEProblem {false} (ff,g,x,n. tspan,p,noise_rate_prototype= zeros (Float32,length (x),n. nbrown))
320- solve (prob,n. args... ;sensealg= ReverseDiffAdjoint (),n. kwargs... ), st
324+ solve (prob,n. args... ;sensealg= InterpolatingAdjoint (),n. kwargs... ), (state1 = st1, state2 = st2)
321325end
322326
323327"""
0 commit comments