Skip to content

Commit 12ef020

Browse files
just remove staticdudt tests
We already deprecated it, and it throws some things, and it wasn't really documented in the first place, so the best thing to do is just to drop it from the tests.
1 parent b8bb050 commit 12ef020

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

test/neural_de.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ tspan = (0.0f0,1.0f0)
77
dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2))
88
fastdudt = FastChain(FastDense(2,50,tanh),FastDense(50,2))
99
fastcdudt = FastChain(FastDense(2,50,tanh,precache=true,numcols=size(xs)[2]),FastDense(50,2,precache=true,numcols=size(xs)[2]))
10-
staticdudt = FastChain(StaticDense(2,50,tanh),StaticDense(50,2))
1110

1211
NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x)
1312
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x)
@@ -139,20 +138,10 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
139138
goodgrad2 = grads[node.p]
140139
@test goodgradc goodgrad2 rtol=1e-6
141140

142-
node = NeuralODE(staticdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=false),p=p)
143-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
144-
@test ! iszero(grads[x])
145-
@test ! iszero(grads[node.p])
146-
147141
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
148142
goodgrad2 = grads[node.p]
149143
@test goodgrad goodgrad2 rtol = 1e-6
150144

151-
node = NeuralODE(staticdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()),p=p)
152-
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
153-
@test !iszero(grads[x])
154-
@test !iszero(grads[node.p])
155-
156145
@test_throws ErrorException grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
157146

158147
@info "Test some adjoints"

0 commit comments

Comments
 (0)