Skip to content

Commit 48f1743

Browse files
more static removal
1 parent 9e6a6e3 commit 48f1743

File tree

1 file changed

+2
-24
lines changed

1 file changed

+2
-24
lines changed

test/fast_neural_ode.jl

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ datasize = 30
55
tspan = (0.0f0,1.5f0)
66

77
function trueODEfunc(du,u,p,t)
8-
true_A = [-0.1 2.0; -2.0 -0.1]
8+
true_A = Float32[-0.1 2.0; -2.0 -0.1]
99
du .= ((u.^3)'true_A)'
1010
end
1111
t = range(tspan[1],tspan[2],length=datasize)
@@ -27,21 +27,6 @@ function fast_loss_n_ode(p)
2727
loss,pred
2828
end
2929

30-
staticdudt2 = FastChain((x,p) -> x.^3,
31-
StaticDense(2,50,tanh),
32-
StaticDense(50,2))
33-
static_n_ode = NeuralODE(staticdudt2,tspan,Tsit5(),saveat=t)
34-
35-
function static_predict_n_ode(p)
36-
static_n_ode(u0,p)
37-
end
38-
39-
function static_loss_n_ode(p)
40-
pred = static_predict_n_ode(p)
41-
loss = sum(abs2,ode_data .- pred)
42-
loss,pred
43-
end
44-
4530
dudt2 = Flux.Chain((x) -> x.^3,
4631
Flux.Dense(2,50,tanh),
4732
Flux.Dense(50,2))
@@ -60,12 +45,5 @@ end
6045
p = initial_params(fastdudt2)
6146
_p,re = Flux.destructure(dudt2)
6247
@test fastdudt2(ones(2),_p) dudt2(ones(2))
63-
@test staticdudt2(ones(2),_p) dudt2(ones(2))
6448
@test fast_loss_n_ode(p)[1] loss_n_ode(p)[1]
65-
@test static_loss_n_ode(p)[1] loss_n_ode(p)[1]
66-
@test Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)[1] Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3
67-
@test Zygote.gradient((p)->static_loss_n_ode(p)[1], p)[1] Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3
68-
69-
# @btime Zygote.gradient((p)->static_loss_n_ode(p)[1], p)
70-
# @btime Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)
71-
# @btime Zygote.gradient((p)->loss_n_ode(p)[1], p)
49+
@test Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)[1] Zygote.gradient((p)->loss_n_ode(p)[1], p)[1] rtol=4e-3

0 commit comments

Comments
 (0)