Skip to content

Commit 5397e66

Browse files
update a few more
1 parent 12ef020 commit 5397e66

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

test/cnf_test.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,25 @@ end
3131
regularize = false
3232
monte_carlo = false
3333

34-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
34+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
3535
end
3636
@testset "regularize=false & monte_carlo=true" begin
3737
regularize = false
3838
monte_carlo = true
3939

40-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
40+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
4141
end
4242
@testset "regularize=true & monte_carlo=false" begin
4343
regularize = true
4444
monte_carlo = false
4545

46-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
46+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
4747
end
4848
@testset "regularize=true & monte_carlo=true" begin
4949
regularize = true
5050
monte_carlo = true
5151

52-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
52+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
5353
end
5454
end
5555
@testset "AutoReverseDiff as adtype" begin
@@ -59,25 +59,25 @@ end
5959
regularize = false
6060
monte_carlo = false
6161

62-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
62+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
6363
end
6464
@testset "regularize=false & monte_carlo=true" begin
6565
regularize = false
6666
monte_carlo = true
6767

68-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
68+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
6969
end
7070
@testset "regularize=true & monte_carlo=false" begin
7171
regularize = true
7272
monte_carlo = false
7373

74-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
74+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
7575
end
7676
@testset "regularize=true & monte_carlo=true" begin
7777
regularize = true
7878
monte_carlo = true
7979

80-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
80+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
8181
end
8282
end
8383
@testset "AutoTracker as adtype" begin
@@ -87,25 +87,25 @@ end
8787
regularize = false
8888
monte_carlo = false
8989

90-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
90+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
9191
end
9292
@testset "regularize=false & monte_carlo=true" begin
9393
regularize = false
9494
monte_carlo = true
9595

96-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
96+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
9797
end
9898
@testset "regularize=true & monte_carlo=false" begin
9999
regularize = true
100100
monte_carlo = false
101101

102-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
102+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
103103
end
104104
@testset "regularize=true & monte_carlo=true" begin
105105
regularize = true
106106
monte_carlo = true
107107

108-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
108+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
109109
end
110110
end
111111
@testset "AutoZygote as adtype" begin
@@ -115,25 +115,25 @@ end
115115
regularize = false
116116
monte_carlo = false
117117

118-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
118+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
119119
end
120120
@testset "regularize=false & monte_carlo=true" begin
121121
regularize = false
122122
monte_carlo = true
123123

124-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
124+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
125125
end
126126
@testset "regularize=true & monte_carlo=false" begin
127127
regularize = true
128128
monte_carlo = false
129129

130-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
130+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
131131
end
132132
@testset "regularize=true & monte_carlo=true" begin
133133
regularize = true
134134
monte_carlo = true
135135

136-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
136+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
137137
end
138138
end
139139
@testset "AutoFiniteDiff as adtype" begin
@@ -143,25 +143,25 @@ end
143143
regularize = false
144144
monte_carlo = false
145145

146-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
146+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
147147
end
148148
@testset "regularize=false & monte_carlo=true" begin
149149
regularize = false
150150
monte_carlo = true
151151

152-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
152+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
153153
end
154154
@testset "regularize=true & monte_carlo=false" begin
155155
regularize = true
156156
monte_carlo = false
157157

158-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
158+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
159159
end
160160
@testset "regularize=true & monte_carlo=true" begin
161161
regularize = true
162162
monte_carlo = true
163163

164-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
164+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
165165
end
166166
end
167167
end
@@ -185,7 +185,7 @@ end
185185
regularize = false
186186
monte_carlo = false
187187

188-
res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=10)
188+
res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=10)
189189
ffjord_d = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res.u); regularize, monte_carlo)
190190

191191
@test !isnothing(pdf(ffjord_d, train_data))
@@ -211,7 +211,7 @@ end
211211
end
212212

213213
adtype = Optimization.AutoZygote()
214-
res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=100)
214+
res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=100)
215215

216216
actual_pdf = pdf.(data_dist, test_data)
217217
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -239,7 +239,7 @@ end
239239
end
240240

241241
adtype = Optimization.AutoZygote()
242-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=100)
242+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=100)
243243

244244
actual_pdf = pdf.(data_dist, test_data)
245245
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -268,7 +268,7 @@ end
268268
end
269269

270270
adtype = Optimization.AutoZygote()
271-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
271+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)
272272

273273
actual_pdf = pdf(data_dist, test_data)
274274
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -297,7 +297,7 @@ end
297297
end
298298

299299
adtype = Optimization.AutoZygote()
300-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
300+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)
301301

302302
actual_pdf = pdf(data_dist, test_data)
303303
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])

test/newton_neural_ode.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using DiffEqFlux, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random, Tes
44
Random.seed!(100)
55

66
n = 1 # number of ODEs
7-
tspan = (0.0, 1.0)
7+
tspan = (0f0, 1f0)
88

99
d = 5 # number of data pairs
10-
x = rand(n, 5)
11-
y = rand(n, 5)
10+
x = rand(Float32, n, 5)
11+
y = rand(Float32, n, 5)
1212

1313
cb = function (p,l)
1414
@show l
@@ -19,7 +19,7 @@ NN = Flux.Chain(Flux.Dense(n, 5n, tanh),
1919
Flux.Dense(5n, n))
2020

2121
@info "ROCK4"
22-
nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1e-4, saveat=[tspan[end]])
22+
nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1f-4, saveat=[tspan[end]])
2323

2424
loss_function(θ) = Flux.Losses.mse(y, nODE(x, θ)[end])
2525
l1 = loss_function(nODE.p)
@@ -33,7 +33,7 @@ NN = FastChain(FastDense(n, 5n, tanh),
3333
FastDense(5n, n))
3434

3535
@info "ROCK2"
36-
nODE = NeuralODE(NN, tspan, ROCK2(), reltol=1e-4, saveat=[tspan[end]])
36+
nODE = NeuralODE(NN, tspan, ROCK2(), reltol=1f-4, saveat=[tspan[end]])
3737

3838
loss_function(θ) = Flux.Losses.mse(y, nODE(x, θ)[end])
3939
l1 = loss_function(nODE.p)

0 commit comments

Comments
 (0)