Skip to content

Commit 327611f

Browse files
committed
fix
1 parent f862173 commit 327611f

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

src/neural_de.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -196,18 +196,20 @@ function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
196196
end
197197

198198
function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
199-
function dudt_(u,p,t;st=st)
200-
u_, st.state1 = n.model1(u,p.p1,st.state1)
199+
st1 = st.state1
200+
st2 = st.state2
201+
function dudt_(u,p,t;st=st1)
202+
u_, st = n.model1(u,p.p1,st)
201203
return u_
202204
end
203-
function g(u,p,t;st=st)
204-
u_, st.state2 = n.model2(u,p.p2,st.state2)
205+
function g(u,p,t;st=st2)
206+
u_, st = n.model2(u,p.p2,st)
205207
return u_
206208
end
207209

208210
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
209211
prob = SDEProblem{false}(ff,g,x,n.tspan,p)
210-
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st
212+
return solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
211213
end
212214

213215
"""
@@ -306,18 +308,20 @@ function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
306308
end
307309

308310
function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
309-
function dudt_(u,p,t;st=st)
310-
u_, st.state1 = n.model1(u,p.p1,st.state1)
311+
st1 = st.state1
312+
st2 = st.state2
313+
function dudt_(u,p,t;st=st1)
314+
u_, st = n.model1(u,p.p1,st)
311315
return u_
312316
end
313-
function g(u,p,t;st=st)
314-
u_, st.state2 = n.model2(u,p.p2,st.state2)
317+
function g(u,p,t;st=st2)
318+
u_, st = n.model2(u,p.p2,st)
315319
return u_
316320
end
317321

318322
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
319323
prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
320-
solve(prob,n.args...;sensealg=ReverseDiffAdjoint(),n.kwargs...), st
324+
solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
321325
end
322326

323327
"""

0 commit comments

Comments
 (0)