1- abstract type NeuralDELayer <: Function end
1+ abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
22basic_tgrad (u,p,t) = zero (u)
33Flux. trainable (m:: NeuralDELayer ) = (m. p,)
44
@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
6969 end
7070end
7171
72+ Lux. initialparameters (rng:: AbstractRNG , n:: NeuralODE ) = Lux. initialparameters (rng, n. model)
73+ Lux. initialstates (rng:: AbstractRNG , n:: NeuralODE ) = Lux. initialstates (rng, n. model)
74+
7275function (n:: NeuralODE )(x,p= n. p)
7376 dudt_ (u,p,t) = n. re (p)(u)
7477 ff = ODEFunction {false} (dudt_,tgrad= basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
8689end
8790
8891function (n:: NeuralODE{M} )(x,p,st) where {M<: Lux.AbstractExplicitLayer }
89- function dudt (u,p,t)
92+ function dudt (u,p,t;st = st )
9093 u_, st = n. model (u,p,st)
9194 return u_
9295 end
96+
9397 ff = ODEFunction {false} (dudt,tgrad= basic_tgrad)
9498 prob = ODEProblem {false} (ff,x,n. tspan,p)
9599 sense = InterpolatingAdjoint (autojacvec= ZygoteVJP ())
@@ -121,8 +125,7 @@ Arguments:
121125 documentation for more details.
122126
123127"""
124- # struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
125- struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: Lux.AbstractExplicitLayer
128+ struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
126129 p:: P
127130 len:: Int
128131 model1:: M
@@ -180,13 +183,13 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
180183 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
181184end
182185
183- function initialparameters (rng:: AbstractRNG , n:: NeuralDSDE )
186+ function Lux . initialparameters (rng:: AbstractRNG , n:: NeuralDSDE )
184187 p1 = Lux. initialparameters (rng, n. model1)
185188 p2 = Lux. initialparameters (rng, n. model2)
186189 return Lux. ComponentArray ((p1 = p1, p2 = p2))
187190end
188191
189- function initialstates (rng:: AbstractRNG , n:: NeuralDSDE )
192+ function Lux . initialstates (rng:: AbstractRNG , n:: NeuralDSDE )
190193 st1 = Lux. initialstates (rng, n. model1)
191194 st2 = Lux. initialstates (rng, n. model2)
192195 return (state1 = st1, state2 = st2)
@@ -291,14 +294,14 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
291294 solve (prob,n. args... ;sensealg= TrackerAdjoint (),n. kwargs... )
292295end
293296
294- function initialparameters (rng:: AbstractRNG , n:: NeuralSDE )
295- p1 = initialparameters (rng, n. model1)
296- p2 = initialparameters (rng, n. model2)
297+ function Lux . initialparameters (rng:: AbstractRNG , n:: NeuralSDE )
298+ p1 = Lux . initialparameters (rng, n. model1)
299+ p2 = Lux . initialparameters (rng, n. model2)
297300 return Lux. ComponentArray ((p1 = p1, p2 = p2))
298301end
299- function initialstates (rng:: AbstractRNG , n:: NeuralSDE )
300- st1 = initialstates (rng, n. model1)
301- st2 = initialstates (rng, n. model2)
302+ function Lux . initialstates (rng:: AbstractRNG , n:: NeuralSDE )
303+ st1 = Lux . initialstates (rng, n. model1)
304+ st2 = Lux . initialstates (rng, n. model2)
302305 return (state1 = st1, state2 = st2)
303306end
304307
0 commit comments