Skip to content

Commit d818dde

Browse files
Fix issues with Cholesky of ForwardDiff.Dual matrices
1 parent a7faaaf commit d818dde

File tree

4 files changed

+16
-9
lines changed

4 files changed

+16
-9
lines changed

src/ProbNumDiffEq.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ cov2psdmatrix(cov::Diagonal; d) =
5656
cov2psdmatrix(cov::AbstractMatrix; d) =
5757
(@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(Matrix(cholesky(cov).U)))
5858
cov2psdmatrix(cov::PSDMatrix; d) = (@assert size(cov, 1) == size(cov, 2) == d; cov)
59+
make_hermitian_if_fowarddiff(M::AbstractMatrix) = M
60+
make_hermitian_if_fowarddiff(M::AbstractMatrix{<:ForwardDiff.Dual}) = begin
61+
# Since ForwardDiff 1.0 `ishermitian` fails for small fluctuations in the dual
62+
if !ishermitian(ForwardDiff.value.(M))
63+
error("Matrix is not Hermitian.")
64+
end
65+
return 0.5 * (M + M')
66+
end
5967

6068
"""
6169
add!(out, toadd)

src/diffusions/calibration.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Needed for MLE diffusion estimation.
88
invquad
99
invquad(v, M::Matrix; v_cache, M_cache) = begin
1010
v_cache .= v
11+
M = make_hermitian_if_fowarddiff(M)
1112
M_chol = cholesky!(copy!(M_cache, M))
1213
ldiv!(M_chol, v_cache)
1314
dot(v, v_cache)

src/filtering/update.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ function update!(
104104
_matmul!(K2_cache, P_p, H')
105105
end
106106

107+
_S = make_hermitian_if_fowarddiff(_S)
107108
S_chol = length(_S) == 1 ? _S[1] : cholesky!(_S)
108109
rdiv!(K, S_chol)
109110

src/solve.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
function DiffEqBase.__init(
33
prob::DiffEqBase.AbstractODEProblem{uType,tType,false},
44
alg::AbstractEK,
5-
timeseries_init=(),
6-
ts_init=(),
7-
ks_init=(),
8-
recompile::Type{Val{recompile_flag}}=Val{true};
5+
args...;
96
kwargs...,
10-
) where {uType,tType,recompile_flag}
7+
) where {uType,tType}
8+
# @info "Inside ProbNumDiffEq's overloaded DiffEqBase.__init function"
119
@warn "The given problem is in out-of-place form. Since the algorithms in this " *
1210
"package are written for in-place problems, it will be automatically converted."
1311
if prob.f isa DynamicalODEFunction
@@ -43,13 +41,12 @@ function DiffEqBase.__init(
4341
prob.kwargs...,
4442
)
4543
end
44+
45+
# @info "Calling DiffEqBase.__init now"
4646
return DiffEqBase.__init(
4747
_prob,
4848
alg,
49-
timeseries_init,
50-
ts_init,
51-
ks_init,
52-
recompile;
49+
args...;
5350
kwargs...,
5451
)
5552
end

0 commit comments

Comments
 (0)