Skip to content

Commit d367dc6

Browse files
Merge pull request #730 from SciML/staticdudt
just remove staticdudt tests
2 parents b8bb050 + 48f1743 commit d367dc6

File tree

4 files changed

+34
-207
lines changed

4 files changed

+34
-207
lines changed

test/cnf_test.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,25 @@ end
3131
regularize = false
3232
monte_carlo = false
3333

34-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
34+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
3535
end
3636
@testset "regularize=false & monte_carlo=true" begin
3737
regularize = false
3838
monte_carlo = true
3939

40-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
40+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
4141
end
4242
@testset "regularize=true & monte_carlo=false" begin
4343
regularize = true
4444
monte_carlo = false
4545

46-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
46+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
4747
end
4848
@testset "regularize=true & monte_carlo=true" begin
4949
regularize = true
5050
monte_carlo = true
5151

52-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
52+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
5353
end
5454
end
5555
@testset "AutoReverseDiff as adtype" begin
@@ -59,25 +59,25 @@ end
5959
regularize = false
6060
monte_carlo = false
6161

62-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
62+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
6363
end
6464
@testset "regularize=false & monte_carlo=true" begin
6565
regularize = false
6666
monte_carlo = true
6767

68-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
68+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
6969
end
7070
@testset "regularize=true & monte_carlo=false" begin
7171
regularize = true
7272
monte_carlo = false
7373

74-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
74+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
7575
end
7676
@testset "regularize=true & monte_carlo=true" begin
7777
regularize = true
7878
monte_carlo = true
7979

80-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
80+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
8181
end
8282
end
8383
@testset "AutoTracker as adtype" begin
@@ -87,25 +87,25 @@ end
8787
regularize = false
8888
monte_carlo = false
8989

90-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
90+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
9191
end
9292
@testset "regularize=false & monte_carlo=true" begin
9393
regularize = false
9494
monte_carlo = true
9595

96-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
96+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
9797
end
9898
@testset "regularize=true & monte_carlo=false" begin
9999
regularize = true
100100
monte_carlo = false
101101

102-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
102+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
103103
end
104104
@testset "regularize=true & monte_carlo=true" begin
105105
regularize = true
106106
monte_carlo = true
107107

108-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
108+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
109109
end
110110
end
111111
@testset "AutoZygote as adtype" begin
@@ -115,25 +115,25 @@ end
115115
regularize = false
116116
monte_carlo = false
117117

118-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
118+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
119119
end
120120
@testset "regularize=false & monte_carlo=true" begin
121121
regularize = false
122122
monte_carlo = true
123123

124-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
124+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
125125
end
126126
@testset "regularize=true & monte_carlo=false" begin
127127
regularize = true
128128
monte_carlo = false
129129

130-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
130+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
131131
end
132132
@testset "regularize=true & monte_carlo=true" begin
133133
regularize = true
134134
monte_carlo = true
135135

136-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
136+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
137137
end
138138
end
139139
@testset "AutoFiniteDiff as adtype" begin
@@ -143,25 +143,25 @@ end
143143
regularize = false
144144
monte_carlo = false
145145

146-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
146+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
147147
end
148148
@testset "regularize=false & monte_carlo=true" begin
149149
regularize = false
150150
monte_carlo = true
151151

152-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
152+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
153153
end
154154
@testset "regularize=true & monte_carlo=false" begin
155155
regularize = true
156156
monte_carlo = false
157157

158-
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
158+
@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
159159
end
160160
@testset "regularize=true & monte_carlo=true" begin
161161
regularize = true
162162
monte_carlo = true
163163

164-
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
164+
@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
165165
end
166166
end
167167
end
@@ -185,7 +185,7 @@ end
185185
regularize = false
186186
monte_carlo = false
187187

188-
res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=10)
188+
res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=10)
189189
ffjord_d = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res.u); regularize, monte_carlo)
190190

191191
@test !isnothing(pdf(ffjord_d, train_data))
@@ -211,7 +211,7 @@ end
211211
end
212212

213213
adtype = Optimization.AutoZygote()
214-
res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=100)
214+
res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=100)
215215

216216
actual_pdf = pdf.(data_dist, test_data)
217217
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -239,7 +239,7 @@ end
239239
end
240240

241241
adtype = Optimization.AutoZygote()
242-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=100)
242+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=100)
243243

244244
actual_pdf = pdf.(data_dist, test_data)
245245
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -268,7 +268,7 @@ end
268268
end
269269

270270
adtype = Optimization.AutoZygote()
271-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
271+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)
272272

273273
actual_pdf = pdf(data_dist, test_data)
274274
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -293,11 +293,11 @@ end
293293

294294
function loss(θ)
295295
logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
296-
mean(-logpx .+ 0.1 * λ₁ .+ 0.1 * λ₂)
296+
mean(-logpx .+ 1f-1 * λ₁ .+ 1f-1 * λ₂)
297297
end
298298

299299
adtype = Optimization.AutoZygote()
300-
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
300+
res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)
301301

302302
actual_pdf = pdf(data_dist, test_data)
303303
learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])

test/fast_neural_ode.jl

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ datasize = 30
55
tspan = (0.0f0,1.5f0)
66

77
function trueODEfunc(du,u,p,t)
8-
true_A = [-0.1 2.0; -2.0 -0.1]
8+
true_A = Float32[-0.1 2.0; -2.0 -0.1]
99
du .= ((u.^3)'true_A)'
1010
end
1111
t = range(tspan[1],tspan[2],length=datasize)
@@ -27,21 +27,6 @@ function fast_loss_n_ode(p)
2727
loss,pred
2828
end
2929

30-
staticdudt2 = FastChain((x,p) -> x.^3,
31-
StaticDense(2,50,tanh),
32-
StaticDense(50,2))
33-
static_n_ode = NeuralODE(staticdudt2,tspan,Tsit5(),saveat=t)
34-
35-
function static_predict_n_ode(p)
36-
static_n_ode(u0,p)
37-
end
38-
39-
function static_loss_n_ode(p)
40-
pred = static_predict_n_ode(p)
41-
loss = sum(abs2,ode_data .- pred)
42-
loss,pred
43-
end
44-
4530
dudt2 = Flux.Chain((x) -> x.^3,
4631
Flux.Dense(2,50,tanh),
4732
Flux.Dense(50,2))
@@ -60,12 +45,5 @@ end
6045
p = initial_params(fastdudt2)
6146
_p,re = Flux.destructure(dudt2)
6247
@test fastdudt2(ones(2),_p) ≈ dudt2(ones(2))
63-
@test staticdudt2(ones(2),_p) ≈ dudt2(ones(2))
6448
@test fast_loss_n_ode(p)[1] ≈ loss_n_ode(p)[1]
65-
@test static_loss_n_ode(p)[1] ≈ loss_n_ode(p)[1]
66-
@test Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)[1] ≈ Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3
67-
@test Zygote.gradient((p)->static_loss_n_ode(p)[1], p)[1] ≈ Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3
68-
69-
# @btime Zygote.gradient((p)->static_loss_n_ode(p)[1], p)
70-
# @btime Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)
71-
# @btime Zygote.gradient((p)->loss_n_ode(p)[1], p)
49+
@test Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)[1] ≈ Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3

0 commit comments

Comments
 (0)