Skip to content

Commit b3b7ae7

Browse files
authored
Fixed NeuralDSDE
Fixed dereferencing of parameters in NeuralDSDE here for SciML/SciMLSensitivity.jl#623 to work , Lux compatible constructors for all layers have been added in /pull/722
1 parent ebc8142 commit b3b7ae7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/neural_de.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
159159
re2 = nothing
160160
new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2),
161161
typeof(tspan),typeof(args),typeof(kwargs)}(p,
162-
length(p1),model1,re1,model2,re2,tspan,args,kwargs)
162+
Int(1),model1,re1,model2,re2,tspan,args,kwargs)
163163
end
164164
end
165165

@@ -179,13 +179,13 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
179179
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
180180
end
181181

182-
function (n::NeuralDSDE{M})(x,p1,p2,st1,st2) where {M<:Lux.AbstractExplicitLayer}
182+
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
183183
function dudt_(u,p,t)
184-
u_, st1 = n.model1(u,p1,st1)
184+
u_, st1 = n.model1(u,p[1],st1)
185185
return u_
186186
end
187187
function g(u,p,t)
188-
u_, st2 = n.model2(u,p2,st2)
188+
u_, st2 = n.model2(u,p[2],st2)
189189
return u_
190190
end
191191

0 commit comments

Comments
 (0)