6666ind (s) = mod (s, 3 ) # sends 1 to 1 and -1 to 2
6767
6868function init_amp (
69- rng:: AbstractRNG ; observations:: ObservationsGLMSBM{R1} , glmsbm:: GLMSBM{R2} , init_std:: R3
69+ rng:: AbstractRNG , observations:: ObservationsGLMSBM{R1} , glmsbm:: GLMSBM{R2} ; init_std:: R3
7070) where {R1,R2,R3}
7171 R = promote_type (R1, R2, R3)
7272 (; N, M) = glmsbm
@@ -96,7 +96,7 @@ function init_amp(
9696end
9797
9898function update_amp! (
99- next_marginals:: MarginalsGLMSBM{R} ;
99+ next_marginals:: MarginalsGLMSBM{R} ,
100100 marginals:: MarginalsGLMSBM{R} ,
101101 observations:: ObservationsGLMSBM ,
102102 glmsbm:: GLMSBM ,
@@ -143,38 +143,21 @@ function update_amp!(
143143 end
144144 @views gₒᵗ⁺¹ = gₒ .(ωᵗ⁺¹, χlᵗ[1 , :], Ref (Vᵗ⁺¹))
145145 Λᵗ⁺¹ = sum (abs2, gₒᵗ⁺¹) / M
146- mul! (Γᵗ⁺¹, F, gₒᵗ⁺¹)
146+ mul! (Γᵗ⁺¹, F' , gₒᵗ⁺¹)
147147 Γᵗ⁺¹ .+ = Λᵗ⁺¹ .* ŵᵗ
148148
149149 # AMP update of the estimated marginals a, v
150- ŵᵗ⁺¹ .= fₐ .(Ref (Pʷ), Ref ( Λᵗ⁺¹) , Γᵗ⁺¹)
151- vᵗ⁺¹ .= fᵥ .(Ref (Pʷ), Ref ( Λᵗ⁺¹) , Γᵗ⁺¹)
150+ ŵᵗ⁺¹ .= fₐ .(Ref (Pʷ), Λᵗ⁺¹, Γᵗ⁺¹)
151+ vᵗ⁺¹ .= fᵥ .(Ref (Pʷ), Λᵗ⁺¹, Γᵗ⁺¹)
152152
153153 # BP update of the field h
154154 hᵗ⁺¹ = Vector {R} (undef, 2 )
155155 for s in (- 1 , 1 )
156- hᵗ⁺¹[ind (s)] = sum (
157- C[ind (s), ind (sμ)] * χᵗ[ind (sμ), μ] for μ in 1 : N for sμ in (- 1 , 1 )
158- )
156+ hᵗ⁺¹[ind (s)] =
157+ sum (C[ind (s), ind (sμ)] * χᵗ[ind (sμ), μ] for μ in 1 : N for sμ in (- 1 , 1 )) / N
159158 end
160159
161160 # BP update of the messages χe and of the marginals χ
162- for μ in 1 : N, ν in neighbors (g, μ)
163- for sμ in (- 1 , 1 )
164- χeᵗ⁺¹[ind (sμ), μ, ν] =
165- prior (R, sμ, Ξ[μ]) * exp (- hᵗ⁺¹[ind (sμ)]) * ψlᵗ⁺¹[ind (sμ), μ]
166- for η in neighbors (g, μ)
167- if η != ν
168- χeᵗ⁺¹[ind (sμ), μ, ν] *= sum (
169- C[ind (sη), ind (sμ)] * χeᵗ[ind (sη), η, μ] for sη in (- 1 , 1 )
170- )
171- end
172- end
173- end
174- normalization = χeᵗ⁺¹[1 , μ, ν] + χeᵗ⁺¹[2 , μ, ν]
175- χeᵗ⁺¹[1 , μ, ν] /= normalization
176- χeᵗ⁺¹[2 , μ, ν] /= normalization
177- end
178161
179162 for μ in 1 : N
180163 for sμ in (- 1 , 1 )
@@ -188,6 +171,16 @@ function update_amp!(
188171 @views χᵗ⁺¹[:, μ] ./= sum (χᵗ⁺¹[:, μ])
189172 end
190173
174+ for μ in 1 : N, ν in neighbors (g, μ)
175+ for sμ in (- 1 , 1 )
176+ extra_factor = sum (C[ind (sν), ind (sμ)] * χeᵗ[ind (sν), ν, μ] for sν in (- 1 , 1 ))
177+ χeᵗ⁺¹[ind (sμ), μ, ν] = χᵗ⁺¹[ind (sμ), μ] / extra_factor
178+ end
179+ normalization = χeᵗ⁺¹[1 , μ, ν] + χeᵗ⁺¹[2 , μ, ν]
180+ χeᵗ⁺¹[1 , μ, ν] /= normalization
181+ χeᵗ⁺¹[2 , μ, ν] /= normalization
182+ end
183+
191184 # BP update of the SBM-to-GLM messages χl
192185 for μ in 1 : N
193186 for sμ in (- 1 , 1 )
@@ -201,7 +194,7 @@ function update_amp!(
201194 @views χlᵗ⁺¹[:, μ] ./= sum (χlᵗ⁺¹[:, μ])
202195 end
203196
204- @views ᵗ ⁺¹ .= 2 .* χ [1 , :] .- one (R)
197+ @views ŝᵗ ⁺¹ .= 2 .* χᵗ⁺¹ [1 , :] .- one (R)
205198
206199 return nothing
207200end
@@ -211,13 +204,13 @@ function run_amp(
211204 observations:: ObservationsGLMSBM ,
212205 glmsbm:: GLMSBM ;
213206 init_std= 1e-3 ,
214- max_iterations= 200 ,
207+ max_iterations= 100 ,
215208 convergence_threshold= 1e-3 ,
216209 recent_past= 10 ,
217210 show_progress= false ,
218211)
219- (; N, M) = csbm
220- (; marginals, next_marginals) = init_amp (rng; observations, glmsbm, init_std)
212+ (; N, M) = glmsbm
213+ (; marginals, next_marginals) = init_amp (rng, observations, glmsbm; init_std)
221214
222215 R = eltype (marginals)
223216 ŝ_history = Matrix {R} (undef, N, max_iterations)
@@ -226,7 +219,7 @@ function run_amp(
226219 prog = Progress (max_iterations; desc= " AMP-BP for GLM-SBM" , enabled= show_progress)
227220
228221 for t in 1 : max_iterations
229- update_amp! (next_marginals; marginals, observations, glmsbm)
222+ update_amp! (next_marginals, marginals, observations, glmsbm)
230223 copy! (marginals, next_marginals)
231224
232225 ŝ_history[:, t] .= marginals. ŝ
@@ -236,8 +229,8 @@ function run_amp(
236229 ŝ_recent_std = typemax (R)
237230 ŵ_recent_std = typemax (R)
238231 else
239- ŝ_recent_std = maximum (std (view (ŝ_history, :, (t - recent_past): t); dims= 2 ))
240- ŵ_recent_std = maximum (std (view (ŵ_history, :, (t - recent_past): t); dims= 2 ))
232+ ŝ_recent_std = mean (std (view (ŝ_history, :, (t - recent_past): t); dims= 2 ))
233+ ŵ_recent_std = mean (std (view (ŵ_history, :, (t - recent_past): t); dims= 2 ))
241234 end
242235 converged = (
243236 ŝ_recent_std < convergence_threshold && ŵ_recent_std < convergence_threshold
261254
262255function evaluate_amp (rng:: AbstractRNG , glmsbm:: GLMSBM ; kwargs... )
263256 (; latents, observations) = rand (rng, glmsbm)
264- (; ŝ_history, ŵ_history, converged) = run_amp (rng, observations, csbm ; kwargs... )
265- qᵤ = discrete_overlap (latents. s, ŝ_history[:, end ])
266- qᵥ = continuous_overlap (latents. w, ŵ_history[:, end ])
267- return (; qᵤ, qᵥ )
257+ (; ŝ_history, ŵ_history, converged) = run_amp (rng, observations, glmsbm ; kwargs... )
258+ q_dis = discrete_overlap (latents. s, ŝ_history[:, end ])
259+ q_cont = continuous_overlap (latents. w, ŵ_history[:, end ])
260+ return (; q_dis, q_cont, converged )
268261end
0 commit comments