diff --git a/src/fit_em.jl b/src/fit_em.jl
index 53aecb7..f26d0ba 100644
--- a/src/fit_em.jl
+++ b/src/fit_em.jl
@@ -124,9 +124,20 @@ function E_step!(
 ) where {T<:AbstractFloat}
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        LL[:, k] .= log(α[k]) .+ logpdf.(dists[k], y)
+        logα, distk = log(α[k]), dists[k]
+        if robust
+            isfinite(logα) || continue
+            for n in eachindex(y)
+                logp = logpdf(distk, y[n])
+                isfinite(logp) || continue
+                LL[n, k] = logα + logp
+            end
+        else
+            for n in eachindex(y)
+                LL[n, k] = logα + logpdf(distk, y[n])
+            end
+        end
     end
-    robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
     # get posterior of each category
     logsumexp!(c, LL) # c[:] = logsumexp(LL, dims=2)
     γ[:, :] .= exp.(LL .- c)
@@ -143,12 +154,20 @@ function E_step!(
 )
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        LL[:, k] .= log(α[k])
-        for n in axes(y, 2)
-            LL[n, k] += logpdf(dists[k], y[:, n])
+        logα, distk = log(α[k]), dists[k]
+        if robust
+            isfinite(logα) || continue
+            for n in axes(y, 2)
+                logp = logpdf(distk, y[:, n])
+                isfinite(logp) || continue
+                LL[n, k] = logα + logp
+            end
+        else
+            for n in axes(y, 2)
+                LL[n, k] = logα + logpdf(distk, y[:, n])
+            end
         end
     end
-    robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
     # get posterior of each category
     c[:] = logsumexp(LL, dims = 2)
     γ[:, :] = exp.(LL .- c)
diff --git a/src/stochastic_em.jl b/src/stochastic_em.jl
index 4cd4da4..c652c34 100644
--- a/src/stochastic_em.jl
+++ b/src/stochastic_em.jl
@@ -59,8 +59,13 @@ function fit_mle!(
     # M-step
     # using ẑ, maximize (update) the parameters
     α[:] = length.(cat)/N
-    dists[:] = [fit_mle(dists[k], y[cat[k]]) for k = 1:K]
-
+    dists[:] = map(1:K) do k
+        if α[k] > 0
+            fit_mle(dists[k], y[cat[k]])
+        else
+            dists[k]
+        end
+    end
     # E-step
     # evaluate likelihood for each type k
     E_step!(LL, c, γ, dists, α, y; robust = robust)
@@ -133,7 +138,13 @@ function fit_mle!(
     # M-step
     # using ẑ, maximize (update) the parameters
     α[:] = length.(cat)/N
-    dists[:] = [fit_mle(dists[k], y[:, cat[k]]) for k = 1:K]
+    dists[:] = map(1:K) do k
+        if α[k] > 0
+            fit_mle(dists[k], y[:, cat[k]])
+        else
+            dists[k]
+        end
+    end
 
     # E-step
     # evaluate likelihood for each type k
diff --git a/test/runtests.jl b/test/runtests.jl
index 6aaccd9..abcfc8e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -260,3 +260,37 @@ end
     ẑ = predict(m, y)
     @test count(ẑ .== z) / N > 0.85
 end
+
+@testset "Test robustness against dropout issue" begin
+    # See https://github.com/dmetivie/ExpectationMaximization.jl/issues/11
+    # In this example, one of the mixture weights goes to zero, producing at iteration 3 an
+    # ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.
+    Random.seed!(1234)
+
+    N = 600
+
+    ctrue = [[-0.3, 1],
+             [-0.4, 0.7],
+             [0.4, -0.6]]
+    X = reduce(hcat, [randn(length(c), N÷3) .+ c for c in ctrue])
+    mix_bad_guess = MixtureModel([MvNormal([1.6, -2.4], [100 0.0; 0.0 1]), MvNormal([-1.1, -0.6], 0.01), MvNormal([0.4, 2.4], 1)])
+
+    fit_mle(mix_bad_guess, X, maxiter = 1)
+
+    try # make sure our test case is problematic after two iterations without the robust option
+        fit_mle(mix_bad_guess, X, maxiter = 20) # triggers error
+        @test false
+    catch e
+        @test true
+    end
+    begin
+        #! no error thrown; however, the EM converges to a bad local maximum!
+        mix_mle_bad = fit_mle(mix_bad_guess, X, maxiter = 2000, robust = true)
+        @test true
+    end
+    begin
+        #! no error thrown; however, the SEM keeps one mixture component with zero probability (it stays the same at every iteration)
+        mix_mle_S = fit_mle(mix_bad_guess, X, method = StochasticEM(), maxiter = 2000)
+        @test true
+    end
+end
\ No newline at end of file
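# --- Usage sketch (editor's note, not part of the patch) ---------------------
# A minimal, hedged example of what the changes above target, using only the API
# already exercised in the test (`fit_mle` with the `robust` and `method` keywords,
# `StochasticEM()`), plus Distributions.jl. The toy data, the deliberately poor
# initial guess, and variable names such as `mix_guess` are illustrative only.
using Distributions, ExpectationMaximization, Random

Random.seed!(42)
y = rand(MixtureModel([Normal(-2, 1), Normal(3, 1)], [0.3, 0.7]), 1_000)

# A poor initial guess can push one component's weight to (numerically) zero; with
# `robust = true` the E-step now skips non-finite log-densities instead of erroring.
mix_guess = MixtureModel([Normal(0, 10), Normal(10, 0.01)], [0.5, 0.5])
mix_fit = fit_mle(mix_guess, y, robust = true, maxiter = 2000)

# The stochastic EM now leaves a component unchanged when no observation is
# assigned to it, instead of calling `fit_mle` on an empty sample.
mix_fit_sem = fit_mle(mix_guess, y, method = StochasticEM(), maxiter = 2000)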