Commit c5b904f

Fix tests and remaining doc
1 parent 0b0a9c4 commit c5b904f

3 files changed (+38 / -32 lines)


docs/src/examples/normalizing_flows.md

Lines changed: 13 additions & 12 deletions
@@ -7,21 +7,21 @@ Now, we study a single layer neural network that can estimate the density `p_x`
 Before getting to the explanation, here's some code to start with. We will
 follow a full explanation of the definition and training process:
 
-```julia
+```@example cnf
 using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux,
       OptimizationOptimJL, Distributions
 
 nn = Flux.Chain(
     Flux.Dense(1, 3, tanh),
     Flux.Dense(3, 1, tanh),
 ) |> f32
-tspan = (0.0f0, 10.0f0)
+tspan = (0.0f0, 1.0f0)
 
 ffjord_mdl = FFJORD(nn, tspan, Tsit5())
 
 # Training
 data_dist = Normal(6.0f0, 0.7f0)
-train_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
 
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
@@ -57,7 +57,7 @@ new_data = rand(ffjord_dist, 100)
 
 We can use DiffEqFlux.jl to define, train and output the densities computed by CNF layers. In the same way as a neural ODE, the layer takes a neural network that defines its derivative function (see [1] for a reference). A possible way to define a CNF layer, would be:
 
-```julia
+```@example cnf2
 using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux,
       OptimizationOptimJL, Distributions
 
@@ -74,16 +74,17 @@ where we also pass as an input the desired timespan for which the differential e
 
 ### Training
 
-First, let's get an array from a normal distribution as the training data
+First, let's get an array from a normal distribution as the training data. Note that we want the data in Float32
+values to match how we have setup the neural network weights and the state space of the ODE.
 
-```julia
+```@example cnf2
 data_dist = Normal(6.0f0, 0.7f0)
-train_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
 ```
 
 Now we define a loss function that we wish to minimize
 
-```julia
+```@example cnf2
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
     -mean(logpx)
@@ -96,7 +97,7 @@ We then train the neural network to learn the distribution of `x`.
 
 Here we showcase starting the optimization with `ADAM` to more quickly find a minimum, and then honing in on the minimum by using `LBFGS`.
 
-```julia
+```@example cnf2
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)
@@ -108,7 +109,7 @@ res1 = Optimization.solve(optprob,
 
 We then complete the training using a different optimizer starting from where `ADAM` stopped.
 
-```julia
+```@example cnf2
 optprob2 = Optimization.OptimizationProblem(optf, res1.u)
 res2 = Optimization.solve(optprob2,
                           Optim.LBFGS(),
@@ -119,7 +120,7 @@ res2 = Optimization.solve(optprob2,
 
 For evaluating the result, we can use `totalvariation` function from `Distances.jl`. First, we compute densities using actual distribution and FFJORD model. then we use a distance function.
 
-```julia
+```@example cnf2
 using Distances
 
 actual_pdf = pdf.(data_dist, train_data)
@@ -131,7 +132,7 @@ train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)
 
 What's more, we can generate new data by using FFJORD as a distribution in `rand`.
 
-```julia
+```@example cnf2
 ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
 new_data = rand(ffjord_dist, 100)
 ```
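
Note on the `Float32.(rand(...))` change above. The following is a hedged sketch, not part of the commit: at the time, `rand` on a univariate `Normal` with `Float32` parameters in Distributions.jl generally still returned `Float64` samples, while the `|> f32` chain carries `Float32` weights, so the broadcast conversion keeps the FFJORD solve from promoting to `Float64`.

```julia
# Sketch of the type mismatch the doc change works around (assumption:
# Distributions.jl returns Float64 samples here, as it typically did).
using Distributions

data_dist  = Normal(6.0f0, 0.7f0)      # Float32 parameters...
raw        = rand(data_dist, 1, 100)   # ...but samples usually come back as Float64
train_data = Float32.(raw)             # broadcast conversion to match the f32 network

eltype(raw), eltype(train_data)        # typically (Float64, Float32)
```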

docs/src/examples/tensor_layer.md

Lines changed: 15 additions & 10 deletions
@@ -13,7 +13,7 @@ To obtain the training data, we solve the equation of motion using one of the
 solvers in `DifferentialEquations`:
 
 ```@example tensor
-using DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, LinearAlgebra
+using DiffEqFlux, Optimization, OptimizationFlux, DifferentialEquations, LinearAlgebra
 k, α, β, γ = 1, 0.1, 0.2, 0.3
 tspan = (0.0,10.0)
 
@@ -24,7 +24,7 @@ end
 
 u0 = [1.0,0.0]
 ts = collect(0.0:0.1:tspan[2])
-prob_train = ODEProblem{true}(dxdt_train,u0,tspan,p=nothing)
+prob_train = ODEProblem{true}(dxdt_train,u0,tspan,p)
 data_train = Array(solve(prob_train,Tsit5(),saveat=ts))
 ```
 
@@ -49,7 +49,7 @@ end
 
 α = zeros(102)
 
-prob_pred = ODEProblem{true}(dxdt_pred,u0,tspan,p=nothing)
+prob_pred = ODEProblem{true}(dxdt_pred,u0,tspan)
 ```
 
 Note that we introduced a "cap" in the neural network term to avoid instabilities
@@ -59,9 +59,9 @@ in order to obtain a faster convergence for this particular example.
 Finally, we introduce the corresponding loss function:
 
 ```@example tensor
-
 function predict_adjoint(θ)
-  x = Array(solve(prob_pred,Tsit5(),p=θ,saveat=ts))
+  x = Array(solve(prob_pred,Tsit5(),p=θ,saveat=ts,
+            sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
 end
 
 function loss_adjoint(θ)
@@ -70,8 +70,13 @@ function loss_adjoint(θ)
   return loss
 end
 
-function cb(θ,l)
-  @show θ, l
+iter = 0
+function callback(θ,l)
+  global iter
+  iter += 1
+  if iter%10 == 0
+    println(l)
+  end
   return false
 end
 ```
@@ -82,18 +87,18 @@ and we train the network using two rounds of `ADAM`:
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x,p) -> loss_adjoint(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, α)
-res1 = Optimization.solve(optprob, ADAM(0.05), cb = cb, maxiters = 150)
+res1 = Optimization.solve(optprob, ADAM(0.05), callback = callback, maxiters = 150)
 
 optprob2 = Optimization.OptimizationProblem(optf, res1.u)
-res2 = Optimization.solve(optprob2, ADAM(0.001), cb = cb,maxiters = 150)
+res2 = Optimization.solve(optprob2, ADAM(0.001), callback = callback,maxiters = 150)
 opt = res2.u
 ```
 
 We plot the results and we obtain a fairly accurate learned model:
 
 ```@example tensor
 using Plots
-data_pred = predict_adjoint(opt)
+data_pred = predict_adjoint(res1.u)
 plot(ts, data_train[1,:], label = "X (ODE)")
 plot!(ts, data_train[2,:], label = "V (ODE)")
 plot!(ts, data_pred[1,:], label = "X (NN)")
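
Note on the `cb` → `callback` rename above. This is a sketch under the assumption of Optimization.jl's usual callback contract (keyword `callback`, signature `(θ, l)`, return value controlling early termination), not code taken from the commit:

```julia
# Minimal callback sketch: receives current parameters and loss;
# returning `true` would halt the solve early, `false` keeps iterating.
iter = 0
function callback(θ, l)
    global iter += 1
    iter % 10 == 0 && println("iteration $iter: loss = $l")
    return false
end

# Passed by keyword, e.g.:
# res = Optimization.solve(optprob, ADAM(0.05), callback = callback, maxiters = 150)
```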

test/cnf_test.jl

Lines changed: 10 additions & 10 deletions
@@ -17,7 +17,7 @@ end
 ffjord_mdl = FFJORD(nn, tspan, Tsit5())
 
 data_dist = Beta(2.0f0, 2.0f0)
-train_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
 
 function loss(θ; regularize, monte_carlo)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
@@ -167,7 +167,7 @@ end
 ffjord_mdl = FFJORD(nn, tspan, Tsit5())
 
 data_dist = Beta(2.0f0, 2.0f0)
-train_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
 
 function loss(θ; regularize, monte_carlo)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
@@ -196,8 +196,8 @@ end
 monte_carlo = false
 
 data_dist = Beta(7.0f0, 7.0f0)
-train_data = rand(data_dist, 1, 100)
-test_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
+test_data = Float32.(rand(data_dist, 1, 100))
 
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
@@ -224,8 +224,8 @@ end
 monte_carlo = false
 
 data_dist = Normal(6.0f0, 0.7f0)
-train_data = rand(data_dist, 1, 100)
-test_data = rand(data_dist, 1, 100)
+train_data = Float32.(rand(data_dist, 1, 100))
+test_data = Float32.(rand(data_dist, 1, 100))
 
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
@@ -253,8 +253,8 @@ end
 μ = ones(Float32, 2)
 Σ = Diagonal([7.0f0, 7.0f0])
 data_dist = MvNormal(μ, Σ)
-train_data = rand(data_dist, 100)
-test_data = rand(data_dist, 100)
+train_data = Float32.(rand(data_dist, 100))
+test_data = Float32.(rand(data_dist, 100))
 
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
@@ -282,8 +282,8 @@ end
 μ = ones(Float32, 2)
 Σ = Diagonal([7.0f0, 7.0f0])
 data_dist = MvNormal(μ, Σ)
-train_data = rand(data_dist, 100)
-test_data = rand(data_dist, 100)
+train_data = Float32.(rand(data_dist, 100))
+test_data = Float32.(rand(data_dist, 100))
 
 function loss(θ)
     logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
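
The test edits mirror the documentation changes: every sampled dataset is converted to `Float32` so its element type matches the `f32` network the FFJORD tests build. A small illustrative sketch of the invariant being maintained (the `nn` definition here is an assumption for illustration, not copied from the test file):

```julia
# Sketch: the data eltype should match the Float32 weights of the f32 chain.
using Flux, Distributions

nn = Flux.Chain(Flux.Dense(1, 3, tanh), Flux.Dense(3, 1, tanh)) |> f32
train_data = Float32.(rand(Beta(2.0f0, 2.0f0), 1, 100))

eltype(train_data) == eltype(first(Flux.params(nn)))    # true: both Float32
```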
