Skip to content

Commit 8c47105

Browse files
committed
initial changes for neuraldsde and neuralsde
1 parent 0b45582 commit 8c47105

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

docs/src/examples/neural_sde.md

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,24 @@ neural SDE with diagonal noise layer function:
7676
drift_dudt = Lux.Chain(ActivationFunction(x -> x.^3),
7777
Lux.Dense(2, 50, tanh),
7878
Lux.Dense(50, 2))
79-
p1, st1 = Lux.setup(rng, drift_dudt)
8079
8180
diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
82-
p2, st2 = Lux.setup(rng, diffusion_dudt)
8381
84-
p1 = Lux.ComponentArray(p1)
85-
p2 = Lux.ComponentArray(p2)
86-
#Component Arrays doesn't provide a name to the first ComponentVector, only subsequent ones get a name for dereferencing
87-
p = [p1, p2]
8882
8983
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
9084
saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
85+
86+
p, st = Lux.setup(rng, neuralsde)
9187
```
9288

9389
Let's see what that looks like:
9490

9591
```@example nsde
9692
# Get the prediction using the correct initial condition
97-
prediction0, st1, st2 = neuralsde(u0,p,st1,st2)
93+
prediction0, st = neuralsde(u0,p,st)
9894
99-
drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
100-
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]
95+
drift_(u, p, t) = drift_dudt(u, p.p1, st.state1)[1]
96+
diffusion_(u, p, t) = diffusion_dudt(u, p.p2, st.state2)[1]
10197
10298
prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), p)
10399
@@ -119,7 +115,7 @@ the data values:
119115

120116
```@example nsde
121117
function predict_neuralsde(p, u = u0)
122-
return Array(neuralsde(u, p, st1, st2)[1])
118+
return Array(neuralsde(u, p, st)[1])
123119
end
124120
125121
function loss_neuralsde(p; n = 100)

src/neural_de.jl

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ Arguments:
121121
documentation for more details.
122122
123123
"""
124-
struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
124+
# struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
125+
struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: Lux.AbstractExplicitLayer
125126
p::P
126127
len::Int
127128
model1::M
@@ -179,19 +180,31 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
179180
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
180181
end
181182

182-
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
183-
function dudt_(u,p,t)
184-
u_, st1 = n.model1(u,p[1],st1)
183+
function initialparameters(rng::AbstractRNG, n::NeuralDSDE)
184+
p1 = Lux.initialparameters(rng, n.model1)
185+
p2 = Lux.initialparameters(rng, n.model2)
186+
return Lux.ComponentArray((p1 = p1, p2 = p2))
187+
end
188+
189+
function initialstates(rng::AbstractRNG, n::NeuralDSDE)
190+
st1 = Lux.initialstates(rng, n.model1)
191+
st2 = Lux.initialstates(rng, n.model2)
192+
return (state1 = st1, state2 = st2)
193+
end
194+
195+
function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
196+
function dudt_(u,p,t;st=st)
197+
u_, st.state1 = n.model1(u,p.p1,st.state1)
185198
return u_
186199
end
187-
function g(u,p,t)
188-
u_, st2 = n.model2(u,p[2],st2)
200+
function g(u,p,t;st=st)
201+
u_, st.state2 = n.model2(u,p.p2,st.state2)
189202
return u_
190203
end
191204

192205
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
193206
prob = SDEProblem{false}(ff,g,x,n.tspan,p)
194-
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
207+
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st
195208
end
196209

197210
"""
@@ -251,6 +264,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
251264
typeof(tspan),typeof(args),typeof(kwargs)}(
252265
p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
253266
end
267+
268+
function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...;
269+
p1 = nothing, p = nothing, kwargs...)
270+
re1 = nothing
271+
re2 = nothing
272+
new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
273+
typeof(tspan),typeof(args),typeof(kwargs)}(
274+
p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
275+
end
254276
end
255277

256278
function (n::NeuralSDE)(x,p=n.p)
@@ -269,6 +291,32 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
269291
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
270292
end
271293

294+
function initialparameters(rng::AbstractRNG, n::NeuralSDE)
295+
p1 = initialparameters(rng, n.model1)
296+
p2 = initialparameters(rng, n.model2)
297+
return Lux.ComponentArray((p1 = p1, p2 = p2))
298+
end
299+
function initialstates(rng::AbstractRNG, n::NeuralSDE)
300+
st1 = initialstates(rng, n.model1)
301+
st2 = initialstates(rng, n.model2)
302+
return (state1 = st1, state2 = st2)
303+
end
304+
305+
function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
306+
function dudt_(u,p,t;st=st)
307+
u_, st.state1 = n.model1(u,p.p1,st.state1)
308+
return u_
309+
end
310+
function g(u,p,t;st=st)
311+
u_, st.state2 = n.model2(u,p.p2,st.state2)
312+
return u_
313+
end
314+
315+
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
316+
prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
317+
solve(prob,n.args...;sensealg=ReverseDiffAdjoint(),n.kwargs...), st
318+
end
319+
272320
"""
273321
Constructs a neural delay differential equation (neural DDE) with constant
274322
delays.

0 commit comments

Comments
 (0)