Skip to content

Commit b844c9a

Browse files
committed
Bayes filter
1 parent 673f022 commit b844c9a

File tree

1 file changed

+45
-13
lines changed

1 file changed

+45
-13
lines changed

docs/src/lecture_13/ode.jl

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ plot!(MXe[2,1:30:end],label="y",color=:red,errorbar=SXe[2,1:30:end])
4848

4949

5050
using LinearAlgebra
51-
function solve(f,x0::AbstractVector,sqΣ0, θ,dt,N,Nr)
51+
function solve(f,x0::AbstractVector,Σ0, θ,dt,N)
5252
n = length(x0)
5353
n2 = 2*length(x0)
5454
Qp = sqrt(n)*[I(n) -I(n)]
5555

5656
X = hcat([zero(x0) for i=1:N]...)
5757
S = hcat([zero(x0) for i=1:N]...)
5858
X[:,1]=x0
59-
Xp = x0 .+ sqΣ0*Qp
60-
sqΣ = sqΣ0
61-
Σ = sqΣ* sqΣ'
59+
Σ=Hermitian(Σ0)
60+
sqΣ = cholesky(Σ).L
61+
Xp = x0 .+ sqΣ*Qp
6262
S[:,1]= diag(Σ)
6363
for t=1:N-1
64-
if rem(t,Nr)==0
65-
Xp .= X[:,t] .+ sqΣ * Qp
66-
end
64+
# if rem(t,Nr)==0
65+
# Xp .= X[:,t] .+ sqΣ * Qp
66+
# end
6767
for i=1:n2 # all quadrature points
6868
Xp[:,i].=Xp[:,i] + [dt*f(Xp[1:2,i],[Xp[3,i];θ[2:end]]);0]
6969
end
@@ -72,17 +72,49 @@ function solve(f,x0::AbstractVector,sqΣ0, θ,dt,N,Nr)
7272
Σ=Matrix((Xp.-mXp)*(Xp.-mXp)'/n2)
7373
S[:,t+1]=sqrt.(diag(Σ))
7474
# @show Σ
75-
76-
sqΣ = cholesky(Σ).L
77-
7875
end
79-
X,S
76+
X,S,Xp
8077
end
8178

8279
## Extension to arbitrary
8380

84-
QX,QS=solve(f,[1.0,1.0,0.1],diagm([0.1,0.1,0.01]),θ0,0.1,1000,1)
81+
QX,QS=solve(f,[1.0,1.0,0.1],diagm([0.1,0.1,0.01]),θ0,0.1,1000)
8582
plot(QX[1,1:30:end],label="x",color=:blue,errorbar=QS[1,1:30:end])
8683
plot!(QX[2,1:30:end],label="y",color=:red,errorbar=QS[2,1:30:end])
8784

88-
savefig("LV_Quadrics.svg")
85+
savefig("LV_Quadrics.svg")
86+
87+
function filter(f,x0::AbstractVector,Σ0, θ,dt,Ne,Y,σy)
88+
XX=[]
89+
SS=[]
90+
Σ=Σ0
91+
x=x0
92+
for t=1:length(Y)
93+
@show x
94+
@show eigen(Σ)
95+
Xt,St,Xp=solve(f,x,Σ,θ,dt,Ne) # prediction
96+
Yp = Xp[1,:] # measure only the first variable
97+
mYp = mean(Yp)
98+
mXp = mean(Xp,dims=2)
99+
SYp = cov(Yp)+σy
100+
C = mean([(Xp[:,i].-mXp)*(Yp[i].-mYp) for i=1:6])
101+
G = C*inv(SYp)
102+
x = vec(Xt[:,end] + G*(Y[t]-mYp))
103+
Σ = cov(Xp,dims=2) - G*SYp*G'
104+
push!(XX,Xt)
105+
push!(SS,St)
106+
end
107+
XX,SS
108+
end
109+
110+
111+
112+
Y = X[1,100:100:end]
113+
Xh,Sh=filter(f,[1.0,1.0,0.1],diagm([0.1,0.1,0.01].^2),θ0,0.1,100,Y,0.01)
114+
XF=hcat(Xh...)
115+
SF=hcat(Sh...)
116+
117+
step=10
118+
plot([1:step:size(XF,2)],XF[1,1:step:end],label="x",color=:blue,errorbar=SF[1,1:step:end])
119+
plot!([1:step:size(XF,2)],XF[2,1:step:end],label="y",color=:red,errorbar=SF[2,1:step:end])
120+
scatter!([100:100:size(XF,2)],Y,label="measured",marker=:xcross)

0 commit comments

Comments
 (0)