Skip to content

Commit 931249a

Browse files
Merge pull request #738 from Abhishek-1Bhatt/neuralsde2
neural_sde example in Flux
2 parents a9d252a + 6d5cc39 commit 931249a

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

docs/src/examples/neural_sde.md

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -33,9 +33,8 @@ First let's build training data from the same example as the neural ODE:
3333

3434
```@example nsde
3535
using Plots, Statistics
36-
using Lux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis, Random
36+
using Flux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis
3737
38-
rng = Random.default_rng()
3938
u0 = Float32[2.; 0.]
4039
datasize = 30
4140
tspan = (0.0f0, 1.0f0)
@@ -73,18 +72,13 @@ Now we build a neural SDE. For simplicity we will use the `NeuralDSDE`
7372
neural SDE with diagonal noise layer function:
7473

7574
```@example nsde
76-
drift_dudt = Lux.Chain(ActivationFunction(x -> x.^3),
77-
Lux.Dense(2, 50, tanh),
78-
Lux.Dense(50, 2))
79-
p1, st1 = Lux.setup(rng, drift_dudt)
75+
drift_dudt = Flux.Chain(x -> x.^3,
76+
Flux.Dense(2, 50, tanh),
77+
Flux.Dense(50, 2))
78+
p1, re1 = Flux.destructure(drift_dudt)
8079
81-
diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
82-
p2, st2 = Lux.setup(rng, diffusion_dudt)
83-
84-
p1 = Lux.ComponentArray(p1)
85-
p2 = Lux.ComponentArray(p2)
86-
#Component Arrays doesn't provide a name to the first ComponentVector, only subsequent ones get a name for dereferencing
87-
p = [p1, p2]
80+
diffusion_dudt = Flux.Chain(Flux.Dense(2, 2))
81+
p2, re2 = Flux.destructure(diffusion_dudt)
8882
8983
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
9084
saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
@@ -94,12 +88,12 @@ Let's see what that looks like:
9488

9589
```@example nsde
9690
# Get the prediction using the correct initial condition
97-
prediction0, st1, st2 = neuralsde(u0,p,st1,st2)
91+
prediction0 = neuralsde(u0)
9892
99-
drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
100-
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]
93+
drift_(u, p, t) = re1(p[1:neuralsde.len])(u)
94+
diffusion_(u, p, t) = re2(p[neuralsde.len+1:end])(u)
10195
102-
prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), p)
96+
prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), neuralsde.p)
10397
10498
ensemble_nprob = EnsembleProblem(prob_neuralsde)
10599
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
@@ -119,7 +113,7 @@ the data values:
119113

120114
```@example nsde
121115
function predict_neuralsde(p, u = u0)
122-
return Array(neuralsde(u, p, st1, st2)[1])
116+
return Array(neuralsde(u, p))
123117
end
124118
125119
function loss_neuralsde(p; n = 100)
@@ -172,7 +166,7 @@ opt = ADAM(0.025)
172166
# First round of training with n = 10
173167
adtype = Optimization.AutoZygote()
174168
optf = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=10), adtype)
175-
optprob = Optimization.OptimizationProblem(optf, p)
169+
optprob = Optimization.OptimizationProblem(optf, neuralsde.p)
176170
result1 = Optimization.solve(optprob, opt,
177171
callback = callback, maxiters = 100)
178172
```

src/neural_de.jl

Lines changed: 63 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
abstract type NeuralDELayer <: Function end
1+
abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
22
basic_tgrad(u,p,t) = zero(u)
33
Flux.trainable(m::NeuralDELayer) = (m.p,)
44

@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
6969
end
7070
end
7171

72+
Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model)
73+
Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model)
74+
7275
function (n::NeuralODE)(x,p=n.p)
7376
dudt_(u,p,t) = n.re(p)(u)
7477
ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
8689
end
8790

8891
function (n::NeuralODE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
89-
function dudt(u,p,t)
92+
function dudt(u,p,t;st=st)
9093
u_, st = n.model(u,p,st)
9194
return u_
9295
end
96+
9397
ff = ODEFunction{false}(dudt,tgrad=basic_tgrad)
9498
prob = ODEProblem{false}(ff,x,n.tspan,p)
9599
sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
@@ -179,19 +183,33 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
179183
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
180184
end
181185

182-
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
183-
function dudt_(u,p,t)
184-
u_, st1 = n.model1(u,p[1],st1)
186+
function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE)
187+
p1 = Lux.initialparameters(rng, n.model1)
188+
p2 = Lux.initialparameters(rng, n.model2)
189+
return Lux.ComponentArray((p1 = p1, p2 = p2))
190+
end
191+
192+
function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
193+
st1 = Lux.initialstates(rng, n.model1)
194+
st2 = Lux.initialstates(rng, n.model2)
195+
return (state1 = st1, state2 = st2)
196+
end
197+
198+
function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
199+
st1 = st.state1
200+
st2 = st.state2
201+
function dudt_(u,p,t;st=st1)
202+
u_, st = n.model1(u,p.p1,st)
185203
return u_
186204
end
187-
function g(u,p,t)
188-
u_, st2 = n.model2(u,p[2],st2)
205+
function g(u,p,t;st=st2)
206+
u_, st = n.model2(u,p.p2,st)
189207
return u_
190208
end
191209

192210
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
193211
prob = SDEProblem{false}(ff,g,x,n.tspan,p)
194-
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
212+
return solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
195213
end
196214

197215
"""
@@ -251,6 +269,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
251269
typeof(tspan),typeof(args),typeof(kwargs)}(
252270
p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
253271
end
272+
273+
function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...;
274+
p1 = nothing, p = nothing, kwargs...)
275+
re1 = nothing
276+
re2 = nothing
277+
new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
278+
typeof(tspan),typeof(args),typeof(kwargs)}(
279+
p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
280+
end
254281
end
255282

256283
function (n::NeuralSDE)(x,p=n.p)
@@ -269,6 +296,34 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
269296
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
270297
end
271298

299+
function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE)
300+
p1 = Lux.initialparameters(rng, n.model1)
301+
p2 = Lux.initialparameters(rng, n.model2)
302+
return Lux.ComponentArray((p1 = p1, p2 = p2))
303+
end
304+
function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
305+
st1 = Lux.initialstates(rng, n.model1)
306+
st2 = Lux.initialstates(rng, n.model2)
307+
return (state1 = st1, state2 = st2)
308+
end
309+
310+
function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
311+
st1 = st.state1
312+
st2 = st.state2
313+
function dudt_(u,p,t;st=st1)
314+
u_, st = n.model1(u,p.p1,st)
315+
return u_
316+
end
317+
function g(u,p,t;st=st2)
318+
u_, st = n.model2(u,p.p2,st)
319+
return u_
320+
end
321+
322+
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
323+
prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
324+
solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
325+
end
326+
272327
"""
273328
Constructs a neural delay differential equation (neural DDE) with constant
274329
delays.

0 commit comments

Comments
 (0)