Commit b0b1494

tested locally, add GPU tutorials
1 parent c88150c commit b0b1494

File tree

6 files changed

+954
-0
lines changed


docs/Project.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,6 +8,8 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
```

docs/make.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@ using Documenter, DiffEqFlux
 
 ENV["GKSwstype"] = "100"
 using Plots
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 
 include("pages.jl")
```

docs/pages.jl

Lines changed: 3 additions & 0 deletions
```diff
@@ -2,6 +2,9 @@ pages = [
 "DiffEqFlux.jl: High Level Scientific Machine Learning (SciML) Pre-Built Architectures" => "index.md",
 "Differential Equation Machine Learning Tutorials" => Any[
     "examples/neural_ode.md",
+    "examples/GPUs.md",
+    "examples/mnist_neural_ode.md",
+    "examples/mnist_conv_neural_ode.md",
     "examples/augmented_neural_ode.md",
     "examples/collocation.md",
     "examples/normalizing_flows.md",
```

docs/src/examples/GPUs.md

Lines changed: 128 additions & 0 deletions
# Neural ODEs on GPUs

Note that the differential equation solvers will run on the GPU if the initial
condition is a GPU array. Thus, for example, we can define a neural ODE by hand
that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):

```julia
using DifferentialEquations, Flux, DiffEqFlux, DiffEqSensitivity

using Random
rng = Random.default_rng()

model_gpu = Chain(Dense(2, 50, tanh), Dense(50, 2)) |> gpu
p, re = Flux.destructure(model_gpu)
dudt!(u, p, t) = re(p)(u)

# Simulation interval and intermediary points
tspan = (0f0, 10f0)
tsteps = 0f0:1f-1:10f0

u0 = Float32[2.0; 0.0] |> gpu
prob_gpu = ODEProblem(dudt!, u0, tspan, p)

# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps)
```
27+
28+
Or we could directly use the neural ODE layer function, like:
29+
30+
```julia
31+
prob_neuralode_gpu = NeuralODE(gpu(model_gpu), tspan, Tsit5(), saveat = tsteps)
32+
```
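
To make the connection to the hand-written version above concrete, here is a minimal usage sketch: the Flux-style `NeuralODE` layer is callable on the initial condition, and the result can be converted back to a CPU array for downstream plotting. This assumes the `prob_neuralode_gpu` and `u0` defined above and a working GPU, so it is not run as part of the docs build.

```julia
# Apply the layer to the GPU initial condition; the solve runs on the GPU
sol = prob_neuralode_gpu(u0)

# Bring the trajectory back to the CPU (e.g. for plotting)
pred = Array(sol)
```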

If one is using `Lux.Chain`, then the computation takes place on the GPU with
`f(x, p, st)` if `x`, `p`, and `st` are on the GPU. This commonly looks like:

```julia
import Lux

dudt2 = Lux.Chain(Lux.ActivationFunction(x -> x^3),
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))

u0 = Float32[2.0; 0.0] |> gpu
p, st = Lux.setup(rng, dudt2) .|> gpu

dudt2_(u, p, t) = dudt2(u, p, st)[1]

# Simulation interval and intermediary points
tspan = (0f0, 10f0)
tsteps = 0f0:1f-1:10f0

prob_gpu = ODEProblem(dudt2_, u0, tspan, p)

# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps)
```

or via the `NeuralODE` struct:

```julia
prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
prob_neuralode_gpu(u0, p, st)
```

## Neural ODE Example

Here is the full neural ODE example. Note that we use the `gpu` function so that the
same code works on CPUs and GPUs, dependent on `using CUDA`.

```julia
using Flux, DiffEqFlux, Optimization, OptimizationFlux, Zygote,
      OrdinaryDiffEq, Plots, CUDA, DiffEqSensitivity, Random, ComponentArrays
CUDA.allowscalar(false) # Makes sure no slow operations are occurring

# rng for Lux.setup
rng = Random.default_rng()

# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)

# Make the data into a GPU-based array if the user has a GPU
ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Chain(x -> x.^3, Dense(2, 50, tanh), Dense(50, 2)) |> gpu
u0 = Float32[2.0; 0.0] |> gpu
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function predict_neuralode(p)
    gpu(prob_neuralode(u0, p))
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

# Callback function to observe training
list_plots = []
iter = 0
callback = function (p, l, pred; doplot = false)
    global list_plots, iter
    if iter == 0
        list_plots = []
    end
    iter += 1
    display(l)
    # plot current prediction against data
    plt = scatter(tsteps, Array(ode_data[1, :]), label = "data")
    scatter!(plt, tsteps, Array(pred[1, :]), label = "prediction")
    push!(list_plots, plt)
    if doplot
        display(plot(plt))
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, prob_neuralode.p)
result_neuralode = Optimization.solve(optprob, ADAM(0.05),
                                      callback = callback, maxiters = 300)
```
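
While `gpu` is a no-op when CUDA is unavailable, scripts that must run on both kinds of machines often select the device mover explicitly. A minimal sketch, assuming only the `CUDA.functional()` query (which reports whether a usable GPU is present) and the `gpu`/`cpu` movers from Flux:

```julia
using CUDA, Flux

# Pick the device mover once at startup; everything downstream uses `device`
device = CUDA.functional() ? gpu : cpu

u0 = Float32[2.0; 0.0] |> device
```

This makes the CPU fallback explicit rather than relying on the silent no-op behavior of `gpu`.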
