Skip to content

Commit ae44b4c

Browse files
removing caching tests
1 parent 5397e66 commit ae44b4c

File tree

1 file changed

+0
-109
lines changed

1 file changed

+0
-109
lines changed

test/neural_de.jl

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.]))
66
tspan = (0.0f0,1.0f0)
77
dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2))
88
fastdudt = FastChain(FastDense(2,50,tanh),FastDense(50,2))
9-
fastcdudt = FastChain(FastDense(2,50,tanh,precache=true,numcols=size(xs)[2]),FastDense(50,2,precache=true,numcols=size(xs)[2]))
109

1110
NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x)
1211
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x)
@@ -68,13 +67,6 @@ gradsnc2 = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
6867
@test ! iszero(gradsnc2[xs])
6968
@test ! iszero(gradsnc2[node.p])
7069

71-
nodec = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false,p=pd)
72-
gradsc = Zygote.gradient(()->sum(nodec(x)),Flux.params(x,nodec)) #with cache
73-
@test ! iszero(gradsc[x])
74-
@test ! iszero(gradsc[nodec.p])
75-
@test gradsnc[x] gradsc[x] rtol=1e-6
76-
@test gradsnc[node.p] gradsc[nodec.p] rtol=1e-6
77-
7870
gradsc2 = Zygote.gradient(()->sum(nodec(xs)),Flux.params(xs,nodec))
7971
@test ! iszero(gradsc2[xs])
8072
@test ! iszero(gradsc2[nodec.p])
@@ -87,11 +79,6 @@ gradsnc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
8779
@test ! iszero(gradsnc[x])
8880
@test ! iszero(gradsnc[node.p])
8981

90-
nodec = NeuralODE(fastcdudt, tspan, Tsit5(), abstol=1e-12, reltol=1e-12, save_everystep=false, save_start=false,p=pd)
91-
gradsc = Zygote.gradient(()->sum(nodec(x)),Flux.params(x,nodec))
92-
@test gradsnc[x] gradsc[x] rtol=1e-3
93-
@test gradsnc[node.p] gradsc[nodec.p] rtol=1e-3
94-
9582
node = NeuralODE(fastdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
9683
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
9784
@test ! iszero(grads[x])
@@ -104,11 +91,6 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
10491
goodgrad = grads[node.p]
10592
p = node.p
10693

107-
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
108-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
109-
@test ! iszero(grads[x])
110-
@test ! iszero(grads[node.p])
111-
11294
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
11395
@test ! iszero(grads[xs])
11496
@test ! iszero(grads[node.p])
@@ -127,23 +109,6 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
127109
goodgrad2 = grads[node.p]
128110
@test goodgrad goodgrad2 # Make sure adjoint overloads are correct
129111

130-
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false, sensealg=BacksolveAdjoint(),p=pc)
131-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
132-
@test ! iszero(grads[x])
133-
@test ! iszero(grads[node.p])
134-
135-
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
136-
@test !iszero(grads[xs])
137-
@test ! iszero(grads[node.p])
138-
goodgrad2 = grads[node.p]
139-
@test goodgradc goodgrad2 rtol=1e-6
140-
141-
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
142-
goodgrad2 = grads[node.p]
143-
@test goodgrad goodgrad2 rtol = 1e-6
144-
145-
@test_throws ErrorException grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
146-
147112
@info "Test some adjoints"
148113

149114
# Adjoint
@@ -184,15 +149,6 @@ goodgrad2 = grads[node.p]
184149
@test_broken ! iszero(grads[xs])
185150
@test_broken ! iszero(grads[node.p])
186151

187-
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false)
188-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
189-
@test ! iszero(grads[x])
190-
@test ! iszero(grads[node.p])
191-
192-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
193-
@test_broken ! iszero(grads[xs])
194-
@test_broken ! iszero(grads[node.p])
195-
196152
node = NeuralODE(fastdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0)
197153
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
198154
@test ! iszero(grads[x])
@@ -202,15 +158,6 @@ goodgrad2 = grads[node.p]
202158
@test_broken ! iszero(grads[xs])
203159
@test_broken ! iszero(grads[node.p])
204160

205-
node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0)
206-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
207-
@test ! iszero(grads[x])
208-
@test ! iszero(grads[node.p])
209-
210-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
211-
@test_broken ! iszero(grads[xs])
212-
@test_broken ! iszero(grads[node.p])
213-
214161
node = NeuralODE(fastdudt,tspan,Tsit5(),saveat=0.1)
215162
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
216163
@test ! iszero(grads[x])
@@ -219,15 +166,6 @@ goodgrad2 = grads[node.p]
219166
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
220167
@test_broken ! iszero(grads[xs])
221168
@test_broken ! iszero(grads[node.p])
222-
223-
node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.1)
224-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
225-
@test ! iszero(grads[x])
226-
@test ! iszero(grads[node.p])
227-
228-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
229-
@test_broken ! iszero(grads[xs])
230-
@test_broken ! iszero(grads[node.p])
231169
end
232170

233171
@info "Test Tracker"
@@ -270,15 +208,6 @@ end
270208
@test_broken ! iszero(grads[xs])
271209
@test_broken ! iszero(grads[node.p])
272210

273-
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
274-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
275-
@test ! iszero(grads[x])
276-
@test ! iszero(grads[node.p])
277-
278-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
279-
@test_broken ! iszero(grads[xs])
280-
@test_broken ! iszero(grads[node.p])
281-
282211
node = NeuralODE(fastdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint())
283212
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
284213
@test ! iszero(grads[x])
@@ -288,15 +217,6 @@ end
288217
@test_broken ! iszero(grads[xs])
289218
@test_broken ! iszero(grads[node.p])
290219

291-
node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint())
292-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
293-
@test ! iszero(grads[x])
294-
@test ! iszero(grads[node.p])
295-
296-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
297-
@test_broken ! iszero(grads[xs])
298-
@test_broken ! iszero(grads[node.p])
299-
300220
node = NeuralODE(fastdudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())
301221
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
302222
@test ! iszero(grads[x])
@@ -305,22 +225,12 @@ end
305225
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
306226
@test_broken ! iszero(grads[xs])
307227
@test_broken ! iszero(grads[node.p])
308-
309-
node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())
310-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
311-
@test ! iszero(grads[x])
312-
@test ! iszero(grads[node.p])
313-
314-
@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
315-
@test_broken ! iszero(grads[xs])
316-
@test_broken ! iszero(grads[node.p])
317228
end
318229

319230
@info "Test non-ODEs"
320231

321232
dudt2 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2))
322233
fastdudt2 = FastChain(FastDense(2,50,tanh),FastDense(50,2))
323-
fastcdudt2 = FastChain(FastDense(2,50,tanh,numcols=size(xs)[2],precache=true),FastDense(50,2,numcols=size(xs)[2],precache=true))
324234
NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.1)(x)
325235
sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)
326236

@@ -348,15 +258,6 @@ gradsnc2 = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
348258
@test ! iszero(gradsnc2[sode.p])
349259
@test ! iszero(gradsnc2[sode.p][end])
350260

351-
sodec = NeuralDSDE(fastcdudt,fastcdudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1,p=pd)
352-
Random.seed!(1234)
353-
gradsc = Zygote.gradient(()->sum(sodec(x)),Flux.params(x,sodec))
354-
@test ! iszero(gradsc[x])
355-
@test ! iszero(gradsc[sodec.p])
356-
@test ! iszero(gradsc[sodec.p][end])
357-
@test gradsnc[x] gradsc[x] rtol=1e-6
358-
@test gradsnc[sode.p] gradsc[sodec.p] rtol=1e-6
359-
360261
gradsc2 = Zygote.gradient(()->sum(sodec(xs)),Flux.params(xs,sodec))
361262
@test_broken gradsc2 isa Tuple
362263
@test ! iszero(gradsc2[xs])
@@ -365,7 +266,6 @@ gradsc2 = Zygote.gradient(()->sum(sodec(xs)),Flux.params(xs,sodec))
365266

366267
dudt22 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,4),x->reshape(x,2,2))
367268
fastdudt22 = FastChain(FastDense(2,50,tanh),FastDense(50,4),(x,p)->reshape(x,2,2))
368-
fastcdudt22 = FastChain(FastDense(2,50,tanh,numcols=size(xs)[2],precache=true),FastDense(50,4,numcols=size(xs)[2],precache=true),(x,p)->reshape(x,2,2))
369269
NeuralSDE(dudt,dudt22,(0.0f0,.1f0),2,LambaEM(),saveat=0.01)(x)
370270

371271
sode = NeuralSDE(dudt,dudt22,(0.0f0,0.1f0),2,LambaEM(),saveat=0.0:0.01:0.1)
@@ -393,15 +293,6 @@ gradsnc = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
393293
@test ! iszero(gradsnc[sode.p])
394294
@test ! iszero(gradsnc[sode.p][end])
395295

396-
sodec = NeuralSDE(fastcdudt,fastcdudt22,(0.0f0,0.1f0),2,LambaEM(),saveat=0.0:0.01:0.1,p=pd)
397-
Random.seed!(1234)
398-
gradsc = Zygote.gradient(()->sum(sodec(x)),Flux.params(x,sodec))
399-
@test ! iszero(gradsc[x])
400-
@test ! iszero(gradsc[sodec.p])
401-
@test ! iszero(gradsc[sodec.p][end])
402-
@test gradsnc[x] gradsc[x] rtol=1e-6
403-
@test gradsnc[sode.p] gradsc[sodec.p] rtol=1e-6
404-
405296
@test_broken gradsc = Zygote.gradient(()->sum(sodec(xs)),Flux.params(xs,sodec))
406297
@test_broken ! iszero(gradsc[xs])
407298
@test ! iszero(gradsc[sodec.p])

0 commit comments

Comments
 (0)