@@ -6,7 +6,6 @@ xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.]))
66tspan = (0.0f0 ,1.0f0 )
77dudt = Flux. Chain (Flux. Dense (2 ,50 ,tanh),Flux. Dense (50 ,2 ))
88fastdudt = FastChain (FastDense (2 ,50 ,tanh),FastDense (50 ,2 ))
9- fastcdudt = FastChain (FastDense (2 ,50 ,tanh,precache= true ,numcols= size (xs)[2 ]),FastDense (50 ,2 ,precache= true ,numcols= size (xs)[2 ]))
109
1110NeuralODE (dudt,tspan,Tsit5 (),save_everystep= false ,save_start= false )(x)
1211NeuralODE (dudt,tspan,Tsit5 (),saveat= 0.1 )(x)
@@ -68,13 +67,6 @@ gradsnc2 = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
6867@test ! iszero (gradsnc2[xs])
6968@test ! iszero (gradsnc2[node. p])
7069
71- nodec = NeuralODE (fastcdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false ,p= pd)
72- gradsc = Zygote. gradient (()-> sum (nodec (x)),Flux. params (x,nodec)) # with cache
73- @test ! iszero (gradsc[x])
74- @test ! iszero (gradsc[nodec. p])
75- @test gradsnc[x] ≈ gradsc[x] rtol= 1e-6
76- @test gradsnc[node. p] ≈ gradsc[nodec. p] rtol= 1e-6
77-
7870gradsc2 = Zygote. gradient (()-> sum (nodec (xs)),Flux. params (xs,nodec))
7971@test ! iszero (gradsc2[xs])
8072@test ! iszero (gradsc2[nodec. p])
@@ -87,11 +79,6 @@ gradsnc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
8779@test ! iszero (gradsnc[x])
8880@test ! iszero (gradsnc[node. p])
8981
90- nodec = NeuralODE (fastcdudt, tspan, Tsit5 (), abstol= 1e-12 , reltol= 1e-12 , save_everystep= false , save_start= false ,p= pd)
91- gradsc = Zygote. gradient (()-> sum (nodec (x)),Flux. params (x,nodec))
92- @test gradsnc[x] ≈ gradsc[x] rtol= 1e-3
93- @test gradsnc[node. p] ≈ gradsc[nodec. p] rtol= 1e-3
94-
9582node = NeuralODE (fastdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false ,sensealg= TrackerAdjoint ())
9683grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
9784@test ! iszero (grads[x])
@@ -104,11 +91,6 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
10491goodgrad = grads[node. p]
10592p = node. p
10693
107- node = NeuralODE (fastcdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false ,sensealg= TrackerAdjoint ())
108- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
109- @test ! iszero (grads[x])
110- @test ! iszero (grads[node. p])
111-
11294grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node))
11395@test ! iszero (grads[xs])
11496@test ! iszero (grads[node. p])
@@ -127,23 +109,6 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
127109goodgrad2 = grads[node. p]
128110@test goodgrad ≈ goodgrad2 # Make sure adjoint overloads are correct
129111
130- node = NeuralODE (fastcdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false , sensealg= BacksolveAdjoint (),p= pc)
131- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
132- @test ! iszero (grads[x])
133- @test ! iszero (grads[node. p])
134-
135- grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node))
136- @test ! iszero (grads[xs])
137- @test ! iszero (grads[node. p])
138- goodgrad2 = grads[node. p]
139- @test goodgradc ≈ goodgrad2 rtol= 1e-6
140-
141- grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node))
142- goodgrad2 = grads[node. p]
143- @test goodgrad ≈ goodgrad2 rtol = 1e-6
144-
145- @test_throws ErrorException grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node))
146-
147112@info " Test some adjoints"
148113
149114# Adjoint
@@ -184,15 +149,6 @@ goodgrad2 = grads[node.p]
184149 @test_broken ! iszero (grads[xs])
185150 @test_broken ! iszero (grads[node. p])
186151
187- node = NeuralODE (fastcdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false )
188- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
189- @test ! iszero (grads[x])
190- @test ! iszero (grads[node. p])
191-
192- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
193- @test_broken ! iszero (grads[xs])
194- @test_broken ! iszero (grads[node. p])
195-
196152 node = NeuralODE (fastdudt,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 1.0 )
197153 grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
198154 @test ! iszero (grads[x])
@@ -202,15 +158,6 @@ goodgrad2 = grads[node.p]
202158 @test_broken ! iszero (grads[xs])
203159 @test_broken ! iszero (grads[node. p])
204160
205- node = NeuralODE (fastcdudt,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 1.0 )
206- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
207- @test ! iszero (grads[x])
208- @test ! iszero (grads[node. p])
209-
210- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
211- @test_broken ! iszero (grads[xs])
212- @test_broken ! iszero (grads[node. p])
213-
214161 node = NeuralODE (fastdudt,tspan,Tsit5 (),saveat= 0.1 )
215162 grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
216163 @test ! iszero (grads[x])
@@ -219,15 +166,6 @@ goodgrad2 = grads[node.p]
219166 @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
220167 @test_broken ! iszero (grads[xs])
221168 @test_broken ! iszero (grads[node. p])
222-
223- node = NeuralODE (fastcdudt,tspan,Tsit5 (),saveat= 0.1 )
224- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
225- @test ! iszero (grads[x])
226- @test ! iszero (grads[node. p])
227-
228- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
229- @test_broken ! iszero (grads[xs])
230- @test_broken ! iszero (grads[node. p])
231169end
232170
233171@info " Test Tracker"
270208 @test_broken ! iszero (grads[xs])
271209 @test_broken ! iszero (grads[node. p])
272210
273- node = NeuralODE (fastcdudt,tspan,Tsit5 (),save_everystep= false ,save_start= false ,sensealg= TrackerAdjoint ())
274- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
275- @test ! iszero (grads[x])
276- @test ! iszero (grads[node. p])
277-
278- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
279- @test_broken ! iszero (grads[xs])
280- @test_broken ! iszero (grads[node. p])
281-
282211 node = NeuralODE (fastdudt,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 1.0 ,sensealg= TrackerAdjoint ())
283212 grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
284213 @test ! iszero (grads[x])
288217 @test_broken ! iszero (grads[xs])
289218 @test_broken ! iszero (grads[node. p])
290219
291- node = NeuralODE (fastcdudt,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 1.0 ,sensealg= TrackerAdjoint ())
292- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
293- @test ! iszero (grads[x])
294- @test ! iszero (grads[node. p])
295-
296- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
297- @test_broken ! iszero (grads[xs])
298- @test_broken ! iszero (grads[node. p])
299-
300220 node = NeuralODE (fastdudt,tspan,Tsit5 (),saveat= 0.1 ,sensealg= TrackerAdjoint ())
301221 grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
302222 @test ! iszero (grads[x])
@@ -305,22 +225,12 @@ end
305225 @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
306226 @test_broken ! iszero (grads[xs])
307227 @test_broken ! iszero (grads[node. p])
308-
309- node = NeuralODE (fastcdudt,tspan,Tsit5 (),saveat= 0.1 ,sensealg= TrackerAdjoint ())
310- grads = Zygote. gradient (()-> sum (node (x)),Flux. params (x,node))
311- @test ! iszero (grads[x])
312- @test ! iszero (grads[node. p])
313-
314- @test_broken grads = Zygote. gradient (()-> sum (node (xs)),Flux. params (xs,node)) isa Tuple
315- @test_broken ! iszero (grads[xs])
316- @test_broken ! iszero (grads[node. p])
317228end
318229
319230@info " Test non-ODEs"
320231
321232dudt2 = Flux. Chain (Flux. Dense (2 ,50 ,tanh),Flux. Dense (50 ,2 ))
322233fastdudt2 = FastChain (FastDense (2 ,50 ,tanh),FastDense (50 ,2 ))
323- fastcdudt2 = FastChain (FastDense (2 ,50 ,tanh,numcols= size (xs)[2 ],precache= true ),FastDense (50 ,2 ,numcols= size (xs)[2 ],precache= true ))
324234NeuralDSDE (dudt,dudt2,(0.0f0 ,.1f0 ),SOSRI (),saveat= 0.1 )(x)
325235sode = NeuralDSDE (dudt,dudt2,(0.0f0 ,.1f0 ),SOSRI (),saveat= 0.0 : 0.01 : 0.1 )
326236
@@ -348,15 +258,6 @@ gradsnc2 = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
348258@test ! iszero (gradsnc2[sode. p])
349259@test ! iszero (gradsnc2[sode. p][end ])
350260
351- sodec = NeuralDSDE (fastcdudt,fastcdudt2,(0.0f0 ,.1f0 ),SOSRI (),saveat= 0.0 : 0.01 : 0.1 ,p= pd)
352- Random. seed! (1234 )
353- gradsc = Zygote. gradient (()-> sum (sodec (x)),Flux. params (x,sodec))
354- @test ! iszero (gradsc[x])
355- @test ! iszero (gradsc[sodec. p])
356- @test ! iszero (gradsc[sodec. p][end ])
357- @test gradsnc[x] ≈ gradsc[x] rtol= 1e-6
358- @test gradsnc[sode. p] ≈ gradsc[sodec. p] rtol= 1e-6
359-
360261gradsc2 = Zygote. gradient (()-> sum (sodec (xs)),Flux. params (xs,sodec))
361262@test_broken gradsc2 isa Tuple
362263@test ! iszero (gradsc2[xs])
@@ -365,7 +266,6 @@ gradsc2 = Zygote.gradient(()->sum(sodec(xs)),Flux.params(xs,sodec))
365266
366267dudt22 = Flux. Chain (Flux. Dense (2 ,50 ,tanh),Flux. Dense (50 ,4 ),x-> reshape (x,2 ,2 ))
367268fastdudt22 = FastChain (FastDense (2 ,50 ,tanh),FastDense (50 ,4 ),(x,p)-> reshape (x,2 ,2 ))
368- fastcdudt22 = FastChain (FastDense (2 ,50 ,tanh,numcols= size (xs)[2 ],precache= true ),FastDense (50 ,4 ,numcols= size (xs)[2 ],precache= true ),(x,p)-> reshape (x,2 ,2 ))
369269NeuralSDE (dudt,dudt22,(0.0f0 ,.1f0 ),2 ,LambaEM (),saveat= 0.01 )(x)
370270
371271sode = NeuralSDE (dudt,dudt22,(0.0f0 ,0.1f0 ),2 ,LambaEM (),saveat= 0.0 : 0.01 : 0.1 )
@@ -393,15 +293,6 @@ gradsnc = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
393293@test ! iszero (gradsnc[sode. p])
394294@test ! iszero (gradsnc[sode. p][end ])
395295
396- sodec = NeuralSDE (fastcdudt,fastcdudt22,(0.0f0 ,0.1f0 ),2 ,LambaEM (),saveat= 0.0 : 0.01 : 0.1 ,p= pd)
397- Random. seed! (1234 )
398- gradsc = Zygote. gradient (()-> sum (sodec (x)),Flux. params (x,sodec))
399- @test ! iszero (gradsc[x])
400- @test ! iszero (gradsc[sodec. p])
401- @test ! iszero (gradsc[sodec. p][end ])
402- @test gradsnc[x] ≈ gradsc[x] rtol= 1e-6
403- @test gradsnc[sode. p] ≈ gradsc[sodec. p] rtol= 1e-6
404-
405296@test_broken gradsc = Zygote. gradient (()-> sum (sodec (xs)),Flux. params (xs,sodec))
406297@test_broken ! iszero (gradsc[xs])
407298@test ! iszero (gradsc[sodec. p])
0 commit comments