5959
6060function init_amp (
6161 rng:: AbstractRNG ;
62- observations:: ContextualSBMObservations{R1} ,
63- csbm:: ContextualSBM{R2} ,
64- init_std:: R3 ,
65- ) where {R1,R2,R3}
66- R = promote_type (R1, R2, R3)
62+ observations:: ContextualSBMObservations{R} ,
63+ csbm:: ContextualSBM{R} ,
64+ init_std,
65+ ) where {R}
6766 (; N, P) = csbm
68- (; g) = observations
67+ (; g, Ξ ) = observations
6968
70- û = init_std .* randn (rng, R, N)
71- v̂ = init_std .* randn (rng, R, P)
69+ û = prior₊ .(R, Ξ) + init_std .* randn (rng, R, N)
70+ v̂ = 2 .* prior₊ .(R, Ξ) .- one (R) .+ init_std .* randn (rng, R, P)
7271 χ₊e = Dict {Tuple{Int,Int},R} ()
7372 for i in 1 : N, j in neighbors (g, i)
7473 χ₊e[i, j] = (one (R) / 2 ) + init_std * randn (rng, R)
@@ -86,15 +85,18 @@ function init_amp(
8685 return (; storage, next_storage, temp_storage)
8786end
8887
88+ prior₊ (:: Type{R} , Ξᵢ) where {R} = Ξᵢ == 0 ? one (R) / 2 : R (Ξᵢ == 1 )
89+ prior₋ (:: Type{R} , Ξᵢ) where {R} = Ξᵢ == 0 ? one (R) / 2 : R (Ξᵢ == - 1 )
90+
8991function update_amp! (
90- next_storage:: AMPStorage ,
91- temp_storage:: AMPTempStorage ;
92- storage:: AMPStorage ,
93- observations:: ContextualSBMObservations ,
92+ next_storage:: AMPStorage{R} ,
93+ temp_storage:: AMPTempStorage{R} ;
94+ storage:: AMPStorage{R} ,
95+ observations:: ContextualSBMObservations{R} ,
9496 csbm:: ContextualSBM{R} ,
9597) where {R}
9698 (; d, λ, μ, N, P) = csbm
97- (; g, B) = observations
99+ (; g, B, Ξ ) = observations
98100 (; cᵢ, cₒ) = affinities (csbm)
99101
100102 ûᵗ, v̂ᵗ, χ₊eᵗ = storage. û, storage. v̂, storage. χ₊e
@@ -109,19 +111,19 @@ function update_amp!(
109111 mul! (v̂_no_comm, B, ûᵗ)
110112 v̂_no_comm .*= sqrt (μ / N)
111113 v̂_no_comm .- = (μ / N) .* v̂ᵗ .* (N - ûₜ_sum2)
112- v̂ᵗ⁺¹ .= v̂_no_comm ./ (1 + σᵥ_no_comm)
113- σᵥ = 1 / (1 + σᵥ_no_comm)
114+ v̂ᵗ⁺¹ .= v̂_no_comm ./ (one (R) + σᵥ_no_comm)
115+ σᵥ = one (R) / (one (R) + σᵥ_no_comm)
114116
115117 # BP estimation of u
116118 mul! (û_no_feat, B' , v̂ᵗ⁺¹)
117119 û_no_feat .*= sqrt (μ / N)
118120 û_no_feat .- = (μ / (N / P)) .* σᵥ .* ûᵗ
119121
120122 # Estimation of the field h
121- h₊ = (1 / 2 N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
122- h₋ = (1 / 2 N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
123- h̃₊ .= - h₊ .+ log ( one (R) / 2 ) .+ û_no_feat
124- h̃₋ .= - h₋ .+ log ( one (R) / 2 ) .- û_no_feat
123+ h₊ = (one (R) / 2 N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
124+ h₋ = (one (R) / 2 N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
125+ h̃₊ .= - h₊ .+ log .( prior₊ .(R, Ξ) ) .+ û_no_feat
126+ h̃₋ .= - h₋ .+ log .( prior₋ .(R, Ξ) ) .- û_no_feat
125127
126128 # BP update of the marginals
127129 for i in 1 : N
@@ -147,7 +149,7 @@ function update_amp!(
147149 end
148150
149151 # BP estimation of u
150- ûᵗ⁺¹ .= 2 .* χ₊ .- 1
152+ ûᵗ⁺¹ .= 2 .* χ₊ .- one (R)
151153
152154 return nothing
153155end
@@ -158,18 +160,21 @@ function run_amp(
158160 csbm:: ContextualSBM ,
159161 init_std:: Real = 1e-3 ,
160162 iterations:: Integer = 10 ,
163+ show_progress= false ,
161164)
162165 (; storage, next_storage, temp_storage) = init_amp (rng; observations, csbm, init_std)
163166 storage_history = [copy (storage)]
164- for iter in 1 : iterations
167+ prog = Progress (iterations; desc= " AMP-BP" , enabled= show_progress)
168+ for _ in 1 : iterations
165169 update_amp! (next_storage, temp_storage; storage, observations, csbm)
166170 copy! (storage, next_storage)
167171 push! (storage_history, copy (storage))
172+ next! (prog)
168173 end
169174 return storage_history
170175end
171176
172- function evaluate_amp (; storage:: AMPStorage , latents:: ContextualSBMLatents )
177+ function overlap (; storage:: AMPStorage , latents:: ContextualSBMLatents )
173178 û = sign .(storage. û)
174179 @assert all (abs .(û) .> 0.5 )
175180 u = latents. u
0 commit comments