
Commit d68bd56

Merge pull request #735 from SciML/neural_sde
Neural SDE tutorial
2 parents efd1f23 + 931249a commit d68bd56

File tree

4 files changed (+266, -8 lines)


docs/Project.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,7 +17,9 @@ OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 
 [compat]
 Documenter = "0.27"
```
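To reproduce these additions locally, the two new documentation dependencies can be added through Pkg — a minimal sketch, assuming you are at the repository root:

```julia
using Pkg
Pkg.activate("docs")
# Writes the same UUID entries into docs/Project.toml
Pkg.add(["SciMLBase", "StochasticDiffEq"])
```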

docs/pages.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@ pages = [
     "examples/mnist_neural_ode.md",
     "examples/mnist_conv_neural_ode.md",
     "examples/augmented_neural_ode.md",
+    "examples/neural_sde.md",
     "examples/collocation.md",
     "examples/normalizing_flows.md",
     "examples/hamiltonian_nn.md",
```

docs/src/examples/neural_sde.md

Lines changed: 200 additions & 0 deletions
# Neural Stochastic Differential Equations With Method of Moments

With neural stochastic differential equations, there is once again a helper form
`neural_dmsde` which can be used for the multiplicative noise case (consult the
layers API documentation, or [this full example using the layer
function](https://github.com/MikeInnes/zygote-paper/blob/master/neural_sde/neural_sde.jl)).

However, since there are far too many possible combinations for the API to
support, in many cases you will want to performantly define neural differential
equations for non-ODE systems from scratch. For these systems, it is generally
best to use `TrackerAdjoint` with non-mutating (out-of-place) forms. For
example, the following defines a neural SDE with neural networks for both the
drift and diffusion terms:

```julia
dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)
prob = SDEProblem(dudt, g, x, tspan, nothing)
```
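To make this concrete, here is a minimal self-contained sketch; the network architectures, `x`, `tspan`, and the `saveat` value are illustrative assumptions, not part of the snippet above:

```julia
using DiffEqFlux, StochasticDiffEq, Flux

# Assumed drift and diffusion networks for a 2-state system
model  = Flux.Chain(Flux.Dense(2, 32, tanh), Flux.Dense(32, 2))
model2 = Flux.Chain(Flux.Dense(2, 32, tanh), Flux.Dense(32, 2))

x = Float32[2.0, 0.0]        # assumed initial condition
tspan = (0.0f0, 1.0f0)       # assumed time span

dudt(u, p, t) = model(u)     # out-of-place drift
g(u, p, t) = model2(u)       # out-of-place diagonal diffusion
prob = SDEProblem(dudt, g, x, tspan, nothing)

# TrackerAdjoint is the sensitivity algorithm recommended above;
# it is only exercised when this solve is differentiated through.
sol = solve(prob, SOSRI(); sensealg = TrackerAdjoint(), saveat = 0.1f0)
```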
where `model` and `model2` are different neural networks. The same approach applies to
a neural delay differential equation. Its out-of-place formulation is
`f(u,h,p,t)`, where the history function is called as `h(p, t)`. Thus, for example, if we
want to define a neural delay differential equation which uses the history value at
`p.tau` in the past (so the parameter object `p` must carry the delay `tau`), we can define:

```julia
dudt(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt, u0, h, tspan, p)
```
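A minimal end-to-end sketch of such a neural DDE follows; the network size, the constant history function, the delay value, and the `MethodOfSteps(Tsit5())` solver choice are all illustrative assumptions:

```julia
using DelayDiffEq, Flux

# Assumed network: input is the concatenation [u; h], so 4 inputs for a 2-state system
model = Flux.Chain(Flux.Dense(4, 16, tanh), Flux.Dense(16, 2))

u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 1.0f0)
p = (tau = 0.1f0,)              # assumed: parameters carry the delay

h(p, t) = zeros(Float32, 2)     # assumed constant history for t < tspan[1]

dudt(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt, u0, h, tspan, p; constant_lags = (p.tau,))
sol = solve(prob, MethodOfSteps(Tsit5()))
```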
First, let's build training data from the same example as the neural ODE:

```@example nsde
using Plots, Statistics
using Flux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
```
```@example nsde
function trueSDEfunc(du, u, p, t)
    # Cubic drift rotated by the true interaction matrix
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    # Diagonal multiplicative noise
    du .= mp .* u
end

prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)
```
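As a quick sanity check (not part of the original tutorial), a single trajectory of this ground-truth SDE can be solved and inspected before building the ensemble:

```julia
sol_check = solve(prob_truesde, SOSRI(), saveat = tsteps)
plot(sol_check)
```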
For our dataset we will use DifferentialEquations.jl's [parallel ensemble
interface](http://docs.juliadiffeq.org/dev/features/ensemble.html) to generate
data from the average of 10,000 runs of the SDE:

```@example nsde
# Take a typical sample from the mean
ensemble_prob = EnsembleProblem(prob_truesde)
ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)

sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps))
```
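For reference, `timeseries_point_meanvar` returns the componentwise mean and variance at each saved time point, so for this 2-state system both arrays have one row per state and one column per time step — a hedged shape check, assuming the setup above:

```julia
@assert size(sde_data) == (2, datasize)       # mean of each state at each tstep
@assert size(sde_data_vars) == (2, datasize)  # variance of each state at each tstep
```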
Now we build a neural SDE. For simplicity, we will use the `NeuralDSDE` layer
function, which constructs a neural SDE with diagonal noise:

```@example nsde
drift_dudt = Flux.Chain(x -> x.^3,
                        Flux.Dense(2, 50, tanh),
                        Flux.Dense(50, 2))
p1, re1 = Flux.destructure(drift_dudt)

diffusion_dudt = Flux.Chain(Flux.Dense(2, 2))
p2, re2 = Flux.destructure(diffusion_dudt)

neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                       saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
```
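Internally, `NeuralDSDE` flattens both networks with `Flux.destructure` and concatenates their parameters, with `neuralsde.len` recording where the drift parameters end; a hedged consistency check under that assumption:

```julia
@assert length(neuralsde.p) == length(p1) + length(p2)
@assert neuralsde.len == length(p1)  # split point used below to rebuild the two networks
```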
Let's see what that looks like:

```@example nsde
# Get the prediction using the correct initial condition
prediction0 = neuralsde(u0)

# Rebuild the drift and diffusion networks from the flat parameter vector
drift_(u, p, t) = re1(p[1:neuralsde.len])(u)
diffusion_(u, p, t) = re2(p[neuralsde.len+1:end])(u)

prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), neuralsde.p)

ensemble_nprob = EnsembleProblem(prob_neuralsde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
                      saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)

plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data', lw = 3)

scatter(tsteps, sde_data[1,:], label = "data")
scatter!(tsteps, prediction0[1,:], label = "prediction")
```
Now, just as with the neural ODE, we define a loss function that calculates the
mean and variance from `n` runs at each time point and uses the distance from
the data values:

```@example nsde
function predict_neuralsde(p, u = u0)
    return Array(neuralsde(u, p))
end

function loss_neuralsde(p; n = 100)
    u = repeat(reshape(u0, :, 1), 1, n)  # batch of n copies of the initial condition
    samples = predict_neuralsde(p, u)
    means = mean(samples, dims = 2)
    vars = var(samples, dims = 2, mean = means)[:, 1, :]
    means = means[:, 1, :]
    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
    return loss, means, vars
end
```
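Before training, a single evaluation of this loss gives a baseline to compare against (a hedged sanity check, not part of the original tutorial):

```julia
loss0, means0, vars0 = loss_neuralsde(neuralsde.p, n = 10)
loss0  # should decrease substantially over the training rounds below
```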
```@example nsde
list_plots = []
iter = 0

# Callback function to observe training
callback = function (p, loss, means, vars; doplot = false)
    global list_plots, iter

    if iter == 0
        list_plots = []
    end
    iter += 1

    # loss against current data
    display(loss)

    # plot current prediction against data
    plt = Plots.scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:],
                        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction")
    push!(list_plots, plt)

    if doplot
        display(plt)
    end
    return false
end
```
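The callback can also be invoked by hand to visualize the untrained model, assuming the baseline evaluation sketched above:

```julia
callback(neuralsde.p, loss0, means0, vars0; doplot = true)
```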
Now we train using this loss function. We can pre-train a little bit using a
smaller `n` and then increase it after the model has had some time to adjust towards
the right mean behavior:

```@example nsde
opt = ADAM(0.025)

# First round of training with n = 10
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 10), adtype)
optprob = Optimization.OptimizationProblem(optf, neuralsde.p)
result1 = Optimization.solve(optprob, opt,
                             callback = callback, maxiters = 100)
```
We resume the training with a larger `n`. (Warning: this step takes a couple of
orders of magnitude longer than the previous one.)

```@example nsde
optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 100), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
                             callback = callback, maxiters = 100)
```
And now we plot an ensemble solution of the trained neural SDE:

```@example nsde
_, means, vars = loss_neuralsde(result2.u, n = 1000)

plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
                     label = "data", title = "Neural SDE: After Training",
                     xlabel = "Time")
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")

plt = plot(plt1, plt2, layout = (2, 1))
savefig(plt, "NN_sde_combined.png"); nothing # sde
```

![](https://user-images.githubusercontent.com/1814174/76975872-88dc9100-6909-11ea-80f7-242f661ebad1.png)

Try this with GPUs as well!
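A hedged sketch of what the GPU variant might look like (assuming a functional CUDA.jl installation; `Flux.gpu` moves the networks and initial condition to the device, and everything else is unchanged):

```julia
using CUDA, Flux

drift_gpu = Flux.Chain(x -> x.^3,
                       Flux.Dense(2, 50, tanh),
                       Flux.Dense(50, 2)) |> Flux.gpu
diffusion_gpu = Flux.Chain(Flux.Dense(2, 2)) |> Flux.gpu

neuralsde_gpu = NeuralDSDE(drift_gpu, diffusion_gpu, tspan, SOSRI(),
                           saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
prediction_gpu = neuralsde_gpu(Flux.gpu(u0))
```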

src/neural_de.jl

Lines changed: 63 additions & 8 deletions
```diff
@@ -1,4 +1,4 @@
-abstract type NeuralDELayer <: Function end
+abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
 basic_tgrad(u,p,t) = zero(u)
 Flux.trainable(m::NeuralDELayer) = (m.p,)
 
@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
     end
 end
 
+Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model)
+Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model)
+
 function (n::NeuralODE)(x,p=n.p)
     dudt_(u,p,t) = n.re(p)(u)
     ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
 end
 
 function (n::NeuralODE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
-    function dudt(u,p,t)
+    function dudt(u,p,t;st=st)
         u_, st = n.model(u,p,st)
         return u_
     end
+
     ff = ODEFunction{false}(dudt,tgrad=basic_tgrad)
     prob = ODEProblem{false}(ff,x,n.tspan,p)
     sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
@@ -179,19 +183,33 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
     solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
 end
 
-function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
-    function dudt_(u,p,t)
-        u_, st1 = n.model1(u,p[1],st1)
+function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE)
+    p1 = Lux.initialparameters(rng, n.model1)
+    p2 = Lux.initialparameters(rng, n.model2)
+    return Lux.ComponentArray((p1 = p1, p2 = p2))
+end
+
+function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
+    st1 = Lux.initialstates(rng, n.model1)
+    st2 = Lux.initialstates(rng, n.model2)
+    return (state1 = st1, state2 = st2)
+end
+
+function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
+    st1 = st.state1
+    st2 = st.state2
+    function dudt_(u,p,t;st=st1)
+        u_, st = n.model1(u,p.p1,st)
         return u_
     end
-    function g(u,p,t)
-        u_, st2 = n.model2(u,p[2],st2)
+    function g(u,p,t;st=st2)
+        u_, st = n.model2(u,p.p2,st)
         return u_
     end
 
     ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
     prob = SDEProblem{false}(ff,g,x,n.tspan,p)
-    return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
+    return solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
 end
 
 """
@@ -251,6 +269,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
         typeof(tspan),typeof(args),typeof(kwargs)}(
         p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
     end
+
+    function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...;
+                       p1 = nothing, p = nothing, kwargs...)
+        re1 = nothing
+        re2 = nothing
+        new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
+            typeof(tspan),typeof(args),typeof(kwargs)}(
+            p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
+    end
 end
 
 function (n::NeuralSDE)(x,p=n.p)
@@ -269,6 +296,34 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
     solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
 end
 
+function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE)
+    p1 = Lux.initialparameters(rng, n.model1)
+    p2 = Lux.initialparameters(rng, n.model2)
+    return Lux.ComponentArray((p1 = p1, p2 = p2))
+end
+function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
+    st1 = Lux.initialstates(rng, n.model1)
+    st2 = Lux.initialstates(rng, n.model2)
+    return (state1 = st1, state2 = st2)
+end
+
+function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
+    st1 = st.state1
+    st2 = st.state2
+    function dudt_(u,p,t;st=st1)
+        u_, st = n.model1(u,p.p1,st)
+        return u_
+    end
+    function g(u,p,t;st=st2)
+        u_, st = n.model2(u,p.p2,st)
+        return u_
+    end
+
+    ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
+    prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
+    solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
+end
+
 """
 Constructs a neural delay differential equation (neural DDE) with constant
 delays.
```
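With these methods in place, the Lux-path layers follow Lux's explicit-parameter convention. A hedged usage sketch (the layer sizes, solver, and RNG handling are illustrative assumptions, not part of the diff):

```julia
using Lux, Random, DiffEqFlux, StochasticDiffEq

rng = Random.default_rng()
drift = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
diffusion = Lux.Chain(Lux.Dense(2, 2))

ndsde = NeuralDSDE(drift, diffusion, (0.0f0, 1.0f0), SOSRI(), saveat = 0.1f0)

# Lux.setup dispatches to the initialparameters/initialstates methods added above
p, st = Lux.setup(rng, ndsde)
sol, st = ndsde(Float32[2.0, 0.0], p, st)
```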
