@@ -79,8 +79,11 @@ function Base.rand(rng::AbstractRNG, csbm::ContextualSBM)
7979
8080 r = rand (rng, N, N)
8181 Is, Js = Int[], Int[]
82- for i in 1 : N, j in 1 : i
83- if ((u[i] == u[j]) && (r[i, j] < cᵢ / N)) || ((u[i] != u[j]) && (r[i, j] < cₒ / N))
82+ for i in 1 : N, j in (i + 1 ): N
83+ if (
84+ ((u[i] == u[j]) && (r[i, j] < cᵢ / N)) || # same
85+ ((u[i] != u[j]) && (r[i, j] < cₒ / N)) # diff
86+ )
8487 push! (Is, i)
8588 push! (Js, j)
8689 end
155158 χ₊:: Vector{R}
156159end
157160
161+ function Base. copy (temp_storage:: AMPTempStorage )
162+ return AMPTempStorage (;
163+ û_no_feat= copy (temp_storage. û_no_feat),
164+ v̂_no_graph= copy (temp_storage. v̂_no_graph),
165+ h̃₊= copy (temp_storage. h̃₊),
166+ h̃₋= copy (temp_storage. h̃₋),
167+ χ₊= copy (temp_storage. χ₊),
168+ )
169+ end
170+
158171function init_amp (
159172 rng:: AbstractRNG ;
160173 observations:: ContextualSBMObservations ,
@@ -191,8 +204,8 @@ function update_amp!(
191204 temp_storage:: AMPTempStorage ;
192205 storage:: AMPStorage ,
193206 observations:: ContextualSBMObservations ,
194- csbm:: ContextualSBM ,
195- )
207+ csbm:: ContextualSBM{R} ,
208+ ) where {R}
196209 (; d, λ, μ, N, P) = csbm
197210 (; B, G) = observations
198211 (; cᵢ, cₒ) = affinities (csbm)
@@ -223,18 +236,19 @@ function update_amp!(
223236 h₊ = (1 / N) * sum (cᵢ * (1 + ûᵗ[i]) / 2 + cₒ * (1 - ûᵗ[i]) / 2 for i in 1 : N)
224237 h₋ = (1 / N) * sum (cₒ * (1 + ûᵗ[i]) / 2 + cᵢ * (1 - ûᵗ[i]) / 2 for i in 1 : N)
225238 for i in 1 : N
226- h̃₊[i] = - h₊ + û_no_feat[i]
227- h̃₋[i] = - h₋ - û_no_feat[i]
239+ h̃₊[i] = - h₊ + log ( one (R) / 2 ) + û_no_feat[i]
240+ h̃₋[i] = - h₋ + log ( one (R) / 2 ) - û_no_feat[i]
228241 end
229242
230243 # BP update of the messages
231244 for i in 1 : N, j in neighbors (G, i)
232245 s_ij = h̃₊[i] - h̃₋[i]
233246 for k in neighbors (G, i)
234- k != j || continue
235- num = (cₒ + 2 λ * sqrt (d) * χ₊eᵗ[k, i])
236- den = (cᵢ - 2 λ * sqrt (d) * χ₊eᵗ[k, i])
237- s_ij += log (num / den)
247+ if k != j
248+ num = (cₒ + 2 λ * sqrt (d) * χ₊eᵗ[k, i])
249+ den = (cᵢ - 2 λ * sqrt (d) * χ₊eᵗ[k, i])
250+ s_ij += log (num / den)
251+ end
238252 end
239253 χ₊eᵗ⁺¹[i, j] = sigmoid (s_ij)
240254 end
@@ -267,19 +281,21 @@ function run_amp(
267281)
268282 (; storage, next_storage, temp_storage) = init_amp (rng; observations, csbm, init_std)
269283 storage_history = [copy (storage)]
284+ temp_storage_history = [copy (temp_storage)]
270285 @showprogress " AMP-BP" for iter in 1 : iterations
271286 update_amp! (next_storage, temp_storage; storage, observations, csbm)
272287 copy! (storage, next_storage)
273288 push! (storage_history, copy (storage))
289+ push! (temp_storage_history, copy (temp_storage))
274290 end
275291 return storage_history
276292end
277293
278294function evaluate_amp (; storage:: AMPStorage , latents:: ContextualSBMLatents )
295+ û = sign .(storage. û)
296+ @assert all (abs .(û) .> 0.5 )
279297 u = latents. u
280- N = length (u)
281- û = 2 .* Int .(storage. û .> 0 ) .- 1
282- q̂ᵤ = (1 / N) * max (count_equalities (û, u), count_equalities (û, - u))
298+ q̂ᵤ = (1 / length (û)) * max (count_equalities (û, u), count_equalities (û, - u))
283299 qᵤ = 2 * (q̂ᵤ - 0.5 )
284300 return qᵤ
285301end
0 commit comments