Skip to content

Commit 5efcf62

Browse files
Merge pull request #725 from frankschae/static_dense_test
Update neural_de tests
2 parents 9f648dc + 4b592f3 commit 5efcf62

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

test/neural_de.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,20 @@ grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
139139
goodgrad2 = grads[node.p]
140140
@test goodgradc goodgrad2 rtol=1e-6
141141

142-
node = NeuralODE(staticdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(),p=p)
142+
node = NeuralODE(staticdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=false),p=p)
143143
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
144144
@test ! iszero(grads[x])
145145
@test ! iszero(grads[node.p])
146146

147+
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
148+
goodgrad2 = grads[node.p]
149+
@test goodgrad goodgrad2 rtol = 1e-6
150+
151+
node = NeuralODE(staticdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()),p=p)
152+
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
153+
@test !iszero(grads[x])
154+
@test !iszero(grads[node.p])
155+
147156
@test_throws ErrorException grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
148157

149158
@info "Test some adjoints"

0 commit comments

Comments
 (0)