Skip to content

Commit f862173

Browse files
committed
corrections
1 parent 8c47105 commit f862173

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/neural_de.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type NeuralDELayer <: Function end
1+
abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
22
basic_tgrad(u,p,t) = zero(u)
33
Flux.trainable(m::NeuralDELayer) = (m.p,)
44

@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
6969
end
7070
end
7171

72+
Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model)
73+
Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model)
74+
7275
function (n::NeuralODE)(x,p=n.p)
7376
dudt_(u,p,t) = n.re(p)(u)
7477
ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
8689
end
8790

8891
function (n::NeuralODE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
89-
function dudt(u,p,t)
92+
function dudt(u,p,t;st=st)
9093
u_, st = n.model(u,p,st)
9194
return u_
9295
end
96+
9397
ff = ODEFunction{false}(dudt,tgrad=basic_tgrad)
9498
prob = ODEProblem{false}(ff,x,n.tspan,p)
9599
sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
@@ -121,8 +125,7 @@ Arguments:
121125
documentation for more details.
122126
123127
"""
124-
# struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
125-
struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: Lux.AbstractExplicitLayer
128+
struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
126129
p::P
127130
len::Int
128131
model1::M
@@ -180,13 +183,13 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
180183
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
181184
end
182185

183-
function initialparameters(rng::AbstractRNG, n::NeuralDSDE)
186+
function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE)
184187
p1 = Lux.initialparameters(rng, n.model1)
185188
p2 = Lux.initialparameters(rng, n.model2)
186189
return Lux.ComponentArray((p1 = p1, p2 = p2))
187190
end
188191

189-
function initialstates(rng::AbstractRNG, n::NeuralDSDE)
192+
function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
190193
st1 = Lux.initialstates(rng, n.model1)
191194
st2 = Lux.initialstates(rng, n.model2)
192195
return (state1 = st1, state2 = st2)
@@ -291,14 +294,14 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
291294
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
292295
end
293296

294-
function initialparameters(rng::AbstractRNG, n::NeuralSDE)
295-
p1 = initialparameters(rng, n.model1)
296-
p2 = initialparameters(rng, n.model2)
297+
function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE)
298+
p1 = Lux.initialparameters(rng, n.model1)
299+
p2 = Lux.initialparameters(rng, n.model2)
297300
return Lux.ComponentArray((p1 = p1, p2 = p2))
298301
end
299-
function initialstates(rng::AbstractRNG, n::NeuralSDE)
300-
st1 = initialstates(rng, n.model1)
301-
st2 = initialstates(rng, n.model2)
302+
function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
303+
st1 = Lux.initialstates(rng, n.model1)
304+
st2 = Lux.initialstates(rng, n.model2)
302305
return (state1 = st1, state2 = st2)
303306
end
304307

0 commit comments

Comments
 (0)