@@ -121,7 +121,8 @@ Arguments:
121121 documentation for more details.
122122
123123"""
124- struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
124+ # struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
125+ struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: Lux.AbstractExplicitLayer
125126 p:: P
126127 len:: Int
127128 model1:: M
@@ -179,19 +180,31 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
179180 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
180181end
181182
182- function (n:: NeuralDSDE{M} )(x,p,st1,st2) where {M<: Lux.AbstractExplicitLayer }
183- function dudt_ (u,p,t)
184- u_, st1 = n. model1 (u,p[1 ],st1)
183+ function initialparameters (rng:: AbstractRNG , n:: NeuralDSDE )
184+ p1 = Lux. initialparameters (rng, n. model1)
185+ p2 = Lux. initialparameters (rng, n. model2)
186+ return Lux. ComponentArray ((p1 = p1, p2 = p2))
187+ end
188+
189+ function initialstates (rng:: AbstractRNG , n:: NeuralDSDE )
190+ st1 = Lux. initialstates (rng, n. model1)
191+ st2 = Lux. initialstates (rng, n. model2)
192+ return (state1 = st1, state2 = st2)
193+ end
194+
195+ function (n:: NeuralDSDE{M} )(x,p,st) where {M<: Lux.AbstractExplicitLayer }
196+ function dudt_ (u,p,t;st= st)
197+ u_, st. state1 = n. model1 (u,p. p1,st. state1)
185198 return u_
186199 end
187- function g (u,p,t)
188- u_, st2 = n. model2 (u,p[ 2 ],st2 )
200+ function g (u,p,t;st = st )
201+ u_, st . state2 = n. model2 (u,p. p2,st . state2 )
189202 return u_
190203 end
191204
192205 ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
193206 prob = SDEProblem {false} (ff,g,x,n. tspan,p)
194- return solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... ), st1, st2
207+ return solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... ), st
195208end
196209
197210"""
@@ -251,6 +264,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
251264 typeof (tspan),typeof (args),typeof (kwargs)}(
252265 p,length (p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
253266 end
267+
268+ function NeuralSDE (model1:: Lux.AbstractExplicitLayer , model2:: Lux.AbstractExplicitLayer ,tspan,nbrown,args... ;
269+ p1 = nothing , p = nothing , kwargs... )
270+ re1 = nothing
271+ re2 = nothing
272+ new{typeof (p),typeof (model1),typeof (re1),typeof (model2),typeof (re2),
273+ typeof (tspan),typeof (args),typeof (kwargs)}(
274+ p,Int (1 ),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
275+ end
254276end
255277
256278function (n:: NeuralSDE )(x,p= n. p)
@@ -269,6 +291,32 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
269291 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
270292end
271293
294+ function initialparameters (rng:: AbstractRNG , n:: NeuralSDE )
295+ p1 = initialparameters (rng, n. model1)
296+ p2 = initialparameters (rng, n. model2)
297+ return Lux. ComponentArray ((p1 = p1, p2 = p2))
298+ end
299+ function initialstates (rng:: AbstractRNG , n:: NeuralSDE )
300+ st1 = initialstates (rng, n. model1)
301+ st2 = initialstates (rng, n. model2)
302+ return (state1 = st1, state2 = st2)
303+ end
304+
305+ function (n:: NeuralSDE{P,M} )(x,p,st) where {P,M<: Lux.AbstractExplicitLayer }
306+ function dudt_ (u,p,t;st= st)
307+ u_, st. state1 = n. model1 (u,p. p1,st. state1)
308+ return u_
309+ end
310+ function g (u,p,t;st= st)
311+ u_, st. state2 = n. model2 (u,p. p2,st. state2)
312+ return u_
313+ end
314+
315+ ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
316+ prob = SDEProblem {false} (ff,g,x,n. tspan,p,noise_rate_prototype= zeros (Float32,length (x),n. nbrown))
317+ solve (prob,n. args... ;sensealg= ReverseDiffAdjoint (),n. kwargs... ), st
318+ end
319+
272320"""
273321Constructs a neural delay differential equation (neural DDE) with constant
274322delays.
0 commit comments