
Commit 10b47a3

Neural SDE tutorial
Based off of the old one. Might be easiest to just make this Flux based?

docs/src/examples/neural_sde.md

Lines changed: 206 additions & 0 deletions

# Neural Stochastic Differential Equations With Method of Moments

With neural stochastic differential equations, there is once again a helper form
`neural_dmsde` which can be used for the multiplicative noise case (consult the
layers API documentation, or [this full example using the layer
function](https://github.com/MikeInnes/zygote-paper/blob/master/neural_sde/neural_sde.jl)).

However, since there are far too many possible combinations for the API to
support, in many cases you will want to define performant neural differential
equations for non-ODE systems from scratch. For these systems, it is generally
best to use `TrackerAdjoint` with non-mutating (out-of-place) forms. For
example, the following defines a neural SDE with neural networks for both the
drift and diffusion terms:

```julia
dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)
prob = SDEProblem(dudt, g, x, tspan, nothing)
```
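
To make that snippet concrete, here is a minimal, self-contained sketch. The two
stand-in "networks" (plain closures), the initial condition `x`, and the time
span are illustrative assumptions, not part of the original example:

```julia
using StochasticDiffEq

# Stand-ins for two small neural networks, for illustration only
model = u -> -0.5f0 .* u
model2 = u -> 0.1f0 .* u

x = Float32[2.0, 0.0]      # assumed initial condition
tspan = (0.0f0, 1.0f0)     # assumed time span

dudt(u, p, t) = model(u)   # drift
g(u, p, t) = model2(u)     # diagonal diffusion
prob = SDEProblem(dudt, g, x, tspan, nothing)
sol = solve(prob, SOSRI())
```

Once real networks are substituted in, gradients can be taken through `solve`
(e.g. with `sensealg = TrackerAdjoint()` from DiffEqSensitivity/SciMLSensitivity)
as noted above.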

where `model` and `model2` are different neural networks. The same can apply to
a neural delay differential equation. Its out-of-place formulation is
`f(u, h, p, t)`. Thus, for example, if we want to define a neural delay
differential equation which uses the history value at `p.tau` in the past, we
can define:

```julia
dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt_, u0, h, tspan, nothing)
```
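
As a minimal hypothetical completion of that snippet, the pieces could be wired
together as follows; the lag `tau`, the history function `h`, the initial
condition, and the stand-in `model` are all assumptions made for illustration:

```julia
using DelayDiffEq, OrdinaryDiffEq

# Stand-in for a neural network mapping [u; h] (length 4) to du (length 2)
model = x -> -x[1:2] .+ 0.1f0 .* x[3:4]

u0 = Float32[1.0, 0.5]
h(p, t) = u0                 # constant history before t = 0
p = (tau = 0.2f0,)           # assumed constant lag
tspan = (0.0f0, 5.0f0)

dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt_, u0, h, tspan, p; constant_lags = [p.tau])
sol = solve(prob, MethodOfSteps(Tsit5()))
```

Declaring the lag via `constant_lags` lets the delay solver track the
discontinuities it introduces.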

First, let's build training data from the same example as the neural ODE:

```@example nsde
using Plots, Statistics
using ComponentArrays, Lux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, DiffEqBase.EnsembleAnalysis, Random

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
```

```@example nsde
function trueSDEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)' * true_A)'
end

mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    du .= mp .* u
end

prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)
```
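
Written out, the ground-truth system being simulated is (with the cube applied
elementwise):

```math
du = A^\top u^{3}\, dt + \beta \odot u \, dW_t, \qquad
A = \begin{bmatrix} -0.1 & 2.0 \\ -2.0 & -0.1 \end{bmatrix}, \qquad
\beta = \begin{bmatrix} 0.2 \\ 0.2 \end{bmatrix},
```

i.e. the same cubic drift as in the neural ODE example, plus diagonal
(multiplicative) noise.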

For our dataset we will use DifferentialEquations.jl's [parallel ensemble
interface](http://docs.juliadiffeq.org/dev/features/ensemble.html) to generate
data from the average of 10,000 runs of the SDE:

```@example nsde
# Simulate an ensemble and take the time-pointwise mean and variance as the data
ensemble_prob = EnsembleProblem(prob_truesde)
ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)

sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps))
```
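
The resulting arrays are laid out as state × time, which is what the loss
function below assumes; a quick, optional sanity check:

```julia
size(sde_data)       # (2, 30): mean of each state variable at every saved time point
size(sde_data_vars)  # (2, 30): the corresponding time-pointwise variances
```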

Now we build a neural SDE. For simplicity, we will use the `NeuralDSDE`
layer function, which defines a neural SDE with diagonal noise:

```@example nsde
drift_dudt = Lux.Chain(x -> x .^ 3,
                       Lux.Dense(2, 50, tanh),
                       Lux.Dense(50, 2))
p1, st1 = Lux.setup(rng, drift_dudt)

diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
p2, st2 = Lux.setup(rng, diffusion_dudt)

p1 = ComponentArray(p1)
p2 = ComponentArray(p2)
# ComponentArrays doesn't provide a name for the top-level ComponentVector (only
# nested components get names for dereferencing), so keep the two parameter sets
# in a plain vector
p = [p1, p2]

neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                       saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
```

Let's see what that looks like:

```@example nsde
# Get the prediction using the correct initial condition
prediction0, st1, st2 = neuralsde(u0, p, st1, st2)

drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]

prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), p)

ensemble_nprob = EnsembleProblem(prob_neuralsde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
                      saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)

plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data', lw = 3)

scatter(tsteps, sde_data[1, :], label = "data")
scatter!(tsteps, prediction0[1, :], label = "prediction")
```

Now, just as with the neural ODE, we define a loss function that calculates the
mean and variance from `n` runs at each time point and penalizes the squared
distance from the data's mean and variance:

```@example nsde
function predict_neuralsde(p, u = u0)
    return Array(neuralsde(u, p, st1, st2)[1])
end

function loss_neuralsde(p; n = 100)
    u = repeat(reshape(u0, :, 1), 1, n)
    samples = predict_neuralsde(p, u)
    means = mean(samples, dims = 2)
    vars = var(samples, dims = 2, mean = means)[:, 1, :]
    means = means[:, 1, :]
    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
    return loss, means, vars
end
```
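
This is the moment-matching ("method of moments") objective from the title:
writing the empirical mean and variance of the `n` simulated trajectories at
time `t_i` as `m̂(t_i)` and `v̂(t_i)`, the loss is

```math
L(p) = \sum_i \left\lVert \hat{m}(t_i) - m_{\mathrm{data}}(t_i) \right\rVert^2
     + \sum_i \left\lVert \hat{v}(t_i) - v_{\mathrm{data}}(t_i) \right\rVert^2,
```

where the data moments are the `sde_data` and `sde_data_vars` computed earlier.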

```@example nsde
list_plots = []
iter = 0

# Callback function to observe training
callback = function (p, loss, means, vars; doplot = false)
    global list_plots, iter

    if iter == 0
        list_plots = []
    end
    iter += 1

    # loss against current data
    display(loss)

    # plot current prediction against data
    plt = Plots.scatter(tsteps, sde_data[1, :], yerror = sde_data_vars[1, :],
                        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1, :], ribbon = vars[1, :], label = "prediction")
    push!(list_plots, plt)

    if doplot
        display(plt)
    end
    return false
end
```

Now we train using this loss function. We can pre-train a little bit using a
smaller `n` and then increase it after the optimizer has had some time to adjust
towards the right mean behavior:

```@example nsde
opt = ADAM(0.025)

# First round of training with n = 10
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 10), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result1 = Optimization.solve(optprob, opt,
                             callback = callback, maxiters = 100)
```

We resume the training with a larger `n`. (WARNING: this step takes a couple of
orders of magnitude longer than the previous one.)

```@example nsde
optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 100), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
                             callback = callback, maxiters = 100)
```

And now we plot the mean and variance of a 1000-trajectory ensemble of the
trained neural SDE against the data:

```@example nsde
_, means, vars = loss_neuralsde(result2.u, n = 1000)

plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
                     label = "data", title = "Neural SDE: After Training",
                     xlabel = "Time")
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")

plt = plot(plt1, plt2, layout = (2, 1))
savefig(plt, "NN_sde_combined.png"); nothing # sde
```

![](https://user-images.githubusercontent.com/1814174/76975872-88dc9100-6909-11ea-80f7-242f661ebad1.png)

Try this with GPUs as well!
