1- abstract type NeuralDELayer <: Function end
1+ abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
22basic_tgrad (u,p,t) = zero (u)
33Flux. trainable (m:: NeuralDELayer ) = (m. p,)
44
@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
6969 end
7070end
7171
72+ Lux. initialparameters (rng:: AbstractRNG , n:: NeuralODE ) = Lux. initialparameters (rng, n. model)
73+ Lux. initialstates (rng:: AbstractRNG , n:: NeuralODE ) = Lux. initialstates (rng, n. model)
74+
7275function (n:: NeuralODE )(x,p= n. p)
7376 dudt_ (u,p,t) = n. re (p)(u)
7477 ff = ODEFunction {false} (dudt_,tgrad= basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
8689end
8790
8891function (n:: NeuralODE{M} )(x,p,st) where {M<: Lux.AbstractExplicitLayer }
89- function dudt (u,p,t)
92+ function dudt (u,p,t;st = st )
9093 u_, st = n. model (u,p,st)
9194 return u_
9295 end
96+
9397 ff = ODEFunction {false} (dudt,tgrad= basic_tgrad)
9498 prob = ODEProblem {false} (ff,x,n. tspan,p)
9599 sense = InterpolatingAdjoint (autojacvec= ZygoteVJP ())
@@ -179,19 +183,33 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
179183 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
180184end
181185
182- function (n:: NeuralDSDE{M} )(x,p,st1,st2) where {M<: Lux.AbstractExplicitLayer }
183- function dudt_ (u,p,t)
184- u_, st1 = n. model1 (u,p[1 ],st1)
186+ function Lux. initialparameters (rng:: AbstractRNG , n:: NeuralDSDE )
187+ p1 = Lux. initialparameters (rng, n. model1)
188+ p2 = Lux. initialparameters (rng, n. model2)
189+ return Lux. ComponentArray ((p1 = p1, p2 = p2))
190+ end
191+
192+ function Lux. initialstates (rng:: AbstractRNG , n:: NeuralDSDE )
193+ st1 = Lux. initialstates (rng, n. model1)
194+ st2 = Lux. initialstates (rng, n. model2)
195+ return (state1 = st1, state2 = st2)
196+ end
197+
198+ function (n:: NeuralDSDE{M} )(x,p,st) where {M<: Lux.AbstractExplicitLayer }
199+ st1 = st. state1
200+ st2 = st. state2
201+ function dudt_ (u,p,t;st= st1)
202+ u_, st = n. model1 (u,p. p1,st)
185203 return u_
186204 end
187- function g (u,p,t)
188- u_, st2 = n. model2 (u,p[ 2 ],st2 )
205+ function g (u,p,t;st = st2 )
206+ u_, st = n. model2 (u,p. p2,st )
189207 return u_
190208 end
191209
192210 ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
193211 prob = SDEProblem {false} (ff,g,x,n. tspan,p)
194- return solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... ), st1, st2
212+ return solve (prob,n. args... ;sensealg= InterpolatingAdjoint (),n. kwargs... ), (state1 = st1, state2 = st2)
195213end
196214
197215"""
@@ -251,6 +269,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
251269 typeof (tspan),typeof (args),typeof (kwargs)}(
252270 p,length (p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
253271 end
272+
273+ function NeuralSDE (model1:: Lux.AbstractExplicitLayer , model2:: Lux.AbstractExplicitLayer ,tspan,nbrown,args... ;
274+ p1 = nothing , p = nothing , kwargs... )
275+ re1 = nothing
276+ re2 = nothing
277+ new{typeof (p),typeof (model1),typeof (re1),typeof (model2),typeof (re2),
278+ typeof (tspan),typeof (args),typeof (kwargs)}(
279+ p,Int (1 ),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
280+ end
254281end
255282
256283function (n:: NeuralSDE )(x,p= n. p)
@@ -269,6 +296,34 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
269296 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
270297end
271298
299+ function Lux. initialparameters (rng:: AbstractRNG , n:: NeuralSDE )
300+ p1 = Lux. initialparameters (rng, n. model1)
301+ p2 = Lux. initialparameters (rng, n. model2)
302+ return Lux. ComponentArray ((p1 = p1, p2 = p2))
303+ end
304+ function Lux. initialstates (rng:: AbstractRNG , n:: NeuralSDE )
305+ st1 = Lux. initialstates (rng, n. model1)
306+ st2 = Lux. initialstates (rng, n. model2)
307+ return (state1 = st1, state2 = st2)
308+ end
309+
310+ function (n:: NeuralSDE{P,M} )(x,p,st) where {P,M<: Lux.AbstractExplicitLayer }
311+ st1 = st. state1
312+ st2 = st. state2
313+ function dudt_ (u,p,t;st= st1)
314+ u_, st = n. model1 (u,p. p1,st)
315+ return u_
316+ end
317+ function g (u,p,t;st= st2)
318+ u_, st = n. model2 (u,p. p2,st)
319+ return u_
320+ end
321+
322+ ff = SDEFunction {false} (dudt_,g,tgrad= basic_tgrad)
323+ prob = SDEProblem {false} (ff,g,x,n. tspan,p,noise_rate_prototype= zeros (Float32,length (x),n. nbrown))
324+ solve (prob,n. args... ;sensealg= InterpolatingAdjoint (),n. kwargs... ), (state1 = st1, state2 = st2)
325+ end
326+
272327"""
273328Constructs a neural delay differential equation (neural DDE) with constant
274329delays.
0 commit comments