Commit 0b0a9c4

Merge branch 'master' into doctests
2 parents: 412ee83 + d367dc6

5 files changed (+37, −185 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "DiffEqFlux"
 uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
 authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
-version = "1.49.0"
+version = "1.49.1"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/neural_de.jl

Lines changed: 4 additions & 4 deletions
@@ -159,7 +159,7 @@ struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
 re2 = nothing
 new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2),
 typeof(tspan),typeof(args),typeof(kwargs)}(p,
-length(p1),model1,re1,model2,re2,tspan,args,kwargs)
+Int(1),model1,re1,model2,re2,tspan,args,kwargs)
 end
 end

@@ -179,13 +179,13 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
 solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
 end

-function (n::NeuralDSDE{M})(x,p1,p2,st1,st2) where {M<:Lux.AbstractExplicitLayer}
+function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
 function dudt_(u,p,t)
-u_, st1 = n.model1(u,p1,st1)
+u_, st1 = n.model1(u,p[1],st1)
 return u_
 end
 function g(u,p,t)
-u_, st2 = n.model2(u,p2,st2)
+u_, st2 = n.model2(u,p[2],st2)
 return u_
 end

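The change above collapses the separate drift and diffusion parameter arguments (p1, p2) of the Lux dispatch into a single indexable parameter object p, with p[1] routed to model1 and p[2] to model2. A minimal usage sketch under that assumption (layer names and sizes here are illustrative, not taken from the commit):

using Lux, Random

# Hypothetical drift and diffusion networks standing in for model1/model2.
drift     = Lux.Dense(2 => 2, tanh)
diffusion = Lux.Dense(2 => 2)

rng = Random.default_rng()
p1, st1 = Lux.setup(rng, drift)       # parameters and state for model1
p2, st2 = Lux.setup(rng, diffusion)   # parameters and state for model2

# With `sde` a NeuralDSDE wrapping `drift` and `diffusion` (constructor
# arguments omitted), the call now passes one parameter collection:
#   sde(x, (p1, p2), st1, st2)        # previously: sde(x, p1, p2, st1, st2)
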
test/cnf_test.jl

Lines changed: 26 additions & 32 deletions
@@ -31,25 +31,25 @@ end
 regularize = false
 monte_carlo = false

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=false & monte_carlo=true" begin
 regularize = false
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=false" begin
 regularize = true
 monte_carlo = false

-@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=true" begin
 regularize = true
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 end
 @testset "AutoReverseDiff as adtype" begin
@@ -58,26 +58,23 @@ end
 @testset "regularize=false & monte_carlo=false" begin
 regularize = false
 monte_carlo = false
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=false & monte_carlo=true" begin
 regularize = false
 monte_carlo = true
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=false" begin
 regularize = true
 monte_carlo = false

-@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=true" begin
 regularize = true
 monte_carlo = true
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 end
 @testset "AutoTracker as adtype" begin
@@ -86,26 +83,23 @@ end
 @testset "regularize=false & monte_carlo=false" begin
 regularize = false
 monte_carlo = false
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=false & monte_carlo=true" begin
 regularize = false
 monte_carlo = true
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=false" begin
 regularize = true
 monte_carlo = false

-@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=true" begin
 regularize = true
 monte_carlo = true
-
-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 end
 @testset "AutoZygote as adtype" begin
@@ -115,25 +109,25 @@ end
 regularize = false
 monte_carlo = false

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=false & monte_carlo=true" begin
 regularize = false
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=false" begin
 regularize = true
 monte_carlo = false

-@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=true" begin
 regularize = true
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 end
 @testset "AutoFiniteDiff as adtype" begin
@@ -143,25 +137,25 @@ end
 regularize = false
 monte_carlo = false

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=false & monte_carlo=true" begin
 regularize = false
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=false" begin
 regularize = true
 monte_carlo = false

-@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 @testset "regularize=true & monte_carlo=true" begin
 regularize = true
 monte_carlo = true

-@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=10))
+@test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=10))
 end
 end
 end
@@ -185,7 +179,7 @@ end
 regularize = false
 monte_carlo = false

-res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=10)
+res = DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=10)
 ffjord_d = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res.u); regularize, monte_carlo)

 @test !isnothing(pdf(ffjord_d, train_data))
@@ -211,7 +205,7 @@
 end

 adtype = Optimization.AutoZygote()
-res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; callback= callback, maxiters=100)
+res = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(1f-1), adtype; callback= callback, maxiters=100)

 actual_pdf = pdf.(data_dist, test_data)
 learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -239,7 +233,7 @@
 end

 adtype = Optimization.AutoZygote()
-res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=100)
+res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=100)

 actual_pdf = pdf.(data_dist, test_data)
 learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -268,7 +262,7 @@
 end

 adtype = Optimization.AutoZygote()
-res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
+res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)

 actual_pdf = pdf(data_dist, test_data)
 learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])
@@ -293,11 +287,11 @@

 function loss(θ)
 logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo)
-mean(-logpx .+ 0.1 * λ₁ .+ 0.1 * λ₂)
+mean(-logpx .+ 1f-1 * λ₁ .+ 1f-1 * λ₂)
 end

 adtype = Optimization.AutoZygote()
-res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(0.1), adtype; callback = callback, maxiters=300)
+res = DiffEqFlux.sciml_train(loss, 0.01f0 * ffjord_mdl.p, ADAM(1f-1), adtype; callback = callback, maxiters=300)

 actual_pdf = pdf(data_dist, test_data)
 learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1])

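Across test/cnf_test.jl the Float64 literal 0.1 becomes the Float32 literal 1f-1, both as the ADAM step size and as the regularization weights in the FFJORD loss. A plausible motivation, not stated in the commit itself, is keeping the arithmetic in single precision: multiplying Float32 values by a Float64 literal silently promotes the result to Float64. A minimal sketch of that promotion behavior:

# Float64 literal promotes the product; the Float32 literal does not.
x = rand(Float32, 3)

eltype(0.1  .* x)   # Float64
eltype(1f-1 .* x)   # Float32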