@@ -5,7 +5,7 @@ datasize = 30
55tspan = (0.0f0 ,1.5f0 )
66
77function trueODEfunc (du,u,p,t)
8- true_A = [- 0.1 2.0 ; - 2.0 - 0.1 ]
8+ true_A = Float32 [- 0.1 2.0 ; - 2.0 - 0.1 ]
99 du .= ((u.^ 3 )' true_A)'
1010end
1111t = range (tspan[1 ],tspan[2 ],length= datasize)
@@ -27,21 +27,6 @@ function fast_loss_n_ode(p)
2727 loss,pred
2828end
2929
30- staticdudt2 = FastChain ((x,p) -> x.^ 3 ,
31- StaticDense (2 ,50 ,tanh),
32- StaticDense (50 ,2 ))
33- static_n_ode = NeuralODE (staticdudt2,tspan,Tsit5 (),saveat= t)
34-
35- function static_predict_n_ode (p)
36- static_n_ode (u0,p)
37- end
38-
39- function static_loss_n_ode (p)
40- pred = static_predict_n_ode (p)
41- loss = sum (abs2,ode_data .- pred)
42- loss,pred
43- end
44-
4530dudt2 = Flux. Chain ((x) -> x.^ 3 ,
4631 Flux. Dense (2 ,50 ,tanh),
4732 Flux. Dense (50 ,2 ))
6045p = initial_params (fastdudt2)
6146_p,re = Flux. destructure (dudt2)
6247@test fastdudt2 (ones (2 ),_p) ≈ dudt2 (ones (2 ))
63- @test staticdudt2 (ones (2 ),_p) ≈ dudt2 (ones (2 ))
6448@test fast_loss_n_ode (p)[1 ] ≈ loss_n_ode (p)[1 ]
65- @test static_loss_n_ode (p)[1 ] ≈ loss_n_ode (p)[1 ]
66- @test Zygote. gradient ((p)-> fast_loss_n_ode (p)[1 ], p)[1 ] ≈ Zygote. gradient ((p)-> loss_n_ode (p)[1 ], p)[1 ] rtol= 4e-3
67- @test Zygote. gradient ((p)-> static_loss_n_ode (p)[1 ], p)[1 ] ≈ Zygote. gradient ((p)-> loss_n_ode (p)[1 ], p)[1 ] rtol= 4e-3
68-
69- # @btime Zygote.gradient((p)->static_loss_n_ode(p)[1], p)
70- # @btime Zygote.gradient((p)->fast_loss_n_ode(p)[1], p)
71- # @btime Zygote.gradient((p)->loss_n_ode(p)[1], p)
49+ @test Zygote. gradient ((p)-> fast_loss_n_ode (p)[1 ], p)[1 ] ≈ Zygote. gradient ((p)-> loss_n_ode (p)[1 ], p)[1 ] rtol= 4e-3
0 commit comments