1- # # Storage
1+ # # Marginals
22
33"""
4- AMPStorage
4+ AMPMarginals
55
66# Fields
77
88- `û::Vector`: posterior mean of `u`, length `N`
99- `v̂::Vector`: posterior mean of `v`, length `P`
1010- `χ₊e::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
1111"""
12- @kwdef struct AMPStorage {R<: Real }
12+ @kwdef struct AMPMarginals {R<: Real }
1313 û:: Vector{R}
1414 v̂:: Vector{R}
1515 χ₊e:: Dict{Tuple{Int,Int},R}
1616end
1717
18- function Base. copy (storage:: AMPStorage )
19- return AMPStorage (; û= copy (storage. û), v̂= copy (storage. v̂), χ₊e= copy (storage. χ₊e))
18+ function Base. copy (marginals:: AMPMarginals )
19+ return AMPMarginals (;
20+ û= copy (marginals. û), v̂= copy (marginals. v̂), χ₊e= copy (marginals. χ₊e)
21+ )
22+ end
23+
24+ function Base. copy! (marginals_dest:: AMPMarginals , marginals_source:: AMPMarginals )
25+ marginals_dest. û .= marginals_source. û
26+ marginals_dest. v̂ .= marginals_source. v̂
27+ copy! (marginals_dest. χ₊e, marginals_source. χ₊e)
28+ return marginals_dest
2029end
2130
22- function Base. copy! (storage_dest:: AMPStorage , storage_source:: AMPStorage )
23- storage_dest. û .= storage_source. û
24- storage_dest. v̂ .= storage_source. v̂
25- copy! (storage_dest. χ₊e, storage_source. χ₊e)
26- return storage_dest
31+ function overlaps (;
32+ u:: Vector{<:Integer} , v:: Vector{R} , û:: Vector{R} , v̂:: Vector{R}
33+ ) where {R}
34+ û .= sign .(û)
35+ û[abs .(û) .< eps (R)] .= one (R)
36+
37+ q̂ᵤ = max (freq_equalities (û, u), freq_equalities (û, - u))
38+ qᵤ = 2 * (q̂ᵤ - one (R) / 2 )
39+
40+ q̂ᵥ = max (abs (dot (v̂, v)), abs (dot (v̂, - v)))
41+ qᵥ = q̂ᵥ / (eps (R) + norm (v̂) * norm (v))
42+
43+ return (; qᵤ, qᵥ)
2744end
2845
2946"""
30- AMPTempStorage
47+ AMPStorage
3148
3249# Fields
3350
3754- `h̃₋::Vector`: individual external field for `u=-1`, length `N`
3855- `χ₊::Vector`: marginal probability of `u=1`, length `N`
3956"""
40- @kwdef struct AMPTempStorage {R<: Real }
57+ @kwdef struct AMPStorage {R<: Real }
4158 û_no_feat:: Vector{R}
4259 v̂_no_comm:: Vector{R}
4360 h̃₊:: Vector{R}
4461 h̃₋:: Vector{R}
4562 χ₊:: Vector{R}
4663end
4764
48- function Base. copy (temp_storage :: AMPTempStorage )
49- return AMPTempStorage (;
50- û_no_feat= copy (temp_storage . û_no_feat),
51- v̂_no_comm= copy (temp_storage . v̂_no_comm),
52- h̃₊= copy (temp_storage . h̃₊),
53- h̃₋= copy (temp_storage . h̃₋),
54- χ₊= copy (temp_storage . χ₊),
65+ function Base. copy (storage :: AMPStorage )
66+ return AMPStorage (;
67+ û_no_feat= copy (storage . û_no_feat),
68+ v̂_no_comm= copy (storage . v̂_no_comm),
69+ h̃₊= copy (storage . h̃₊),
70+ h̃₋= copy (storage . h̃₋),
71+ χ₊= copy (storage . χ₊),
5572 )
5673end
5774
@@ -66,42 +83,42 @@ function init_amp(
6683 (; N, P) = csbm
6784 (; g, Ξ) = observations
6885
69- û = prior₊ .(R, Ξ) + init_std .* randn (rng, R, N)
70- v̂ = 2 .* prior₊ .(R, Ξ) .- one (R) .+ init_std .* randn (rng, R, P)
86+ û = 2 .* prior₊ .(R, Ξ) .- one (R) . + init_std .* randn (rng, R, N)
87+ v̂ = init_std .* randn (rng, R, P)
7188 χ₊e = Dict {Tuple{Int,Int},R} ()
7289 for i in 1 : N, j in neighbors (g, i)
73- χ₊e[i, j] = ( one (R) / 2 ) + init_std * randn (rng, R)
90+ χ₊e[i, j] = prior₊ (R, Ξ[i] ) + init_std * randn (rng, R)
7491 end
75- storage = AMPStorage (; û, v̂, χ₊e)
76- next_storage = copy (storage )
92+ marginals = AMPMarginals (; û, v̂, χ₊e)
93+ next_marginals = copy (marginals )
7794
7895 û_no_feat = zeros (R, N)
7996 v̂_no_comm = zeros (R, P)
8097 h̃₊ = zeros (R, N)
8198 h̃₋ = zeros (R, N)
8299 χ₊ = zeros (R, N)
83- temp_storage = AMPTempStorage (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊)
100+ storage = AMPStorage (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊)
84101
85- return (; storage, next_storage, temp_storage )
102+ return (; marginals, next_marginals, storage )
86103end
87104
88105prior₊ (:: Type{R} , Ξᵢ) where {R} = Ξᵢ == 0 ? one (R) / 2 : R (Ξᵢ == 1 )
89106prior₋ (:: Type{R} , Ξᵢ) where {R} = Ξᵢ == 0 ? one (R) / 2 : R (Ξᵢ == - 1 )
90107
91108function update_amp! (
92- next_storage :: AMPStorage {R} ,
93- temp_storage :: AMPTempStorage {R} ;
94- storage :: AMPStorage {R} ,
109+ next_marginals :: AMPMarginals {R} ,
110+ storage :: AMPStorage {R} ;
111+ marginals :: AMPMarginals {R} ,
95112 observations:: ContextualSBMObservations{R} ,
96113 csbm:: ContextualSBM{R} ,
97114) where {R}
98115 (; d, λ, μ, N, P) = csbm
99116 (; g, B, Ξ) = observations
100117 (; cᵢ, cₒ) = affinities (csbm)
101118
102- ûᵗ, v̂ᵗ, χ₊eᵗ = storage . û, storage . v̂, storage . χ₊e
103- ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_storage . û, next_storage . v̂, next_storage . χ₊e
104- (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = temp_storage
119+ ûᵗ, v̂ᵗ, χ₊eᵗ = marginals . û, marginals . v̂, marginals . χ₊e
120+ ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_marginals . û, next_marginals . v̂, next_marginals . χ₊e
121+ (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = storage
105122
106123 ûₜ_sum = sum (ûᵗ)
107124 ûₜ_sum2 = sum (abs2, ûᵗ)
@@ -156,29 +173,52 @@ end
156173
157174function run_amp (
158175 rng:: AbstractRNG ;
159- observations:: ContextualSBMObservations ,
160- csbm:: ContextualSBM ,
176+ observations:: ContextualSBMObservations{R} ,
177+ csbm:: ContextualSBM{R} ,
161178 init_std:: Real = 1e-3 ,
162- iterations:: Integer = 10 ,
163- show_progress= false ,
164- )
165- (; storage, next_storage, temp_storage) = init_amp (rng; observations, csbm, init_std)
166- storage_history = [copy (storage)]
167- prog = Progress (iterations; desc= " AMP-BP" , enabled= show_progress)
168- for _ in 1 : iterations
169- update_amp! (next_storage, temp_storage; storage, observations, csbm)
170- copy! (storage, next_storage)
171- push! (storage_history, copy (storage))
172- next! (prog)
179+ max_iterations:: Integer = 200 ,
180+ convergence_threshold= 1e-3 ,
181+ recent_past= 10 ,
182+ show_progress:: Bool = false ,
183+ ) where {R}
184+ (; N, P) = csbm
185+ (; marginals, next_marginals, storage) = init_amp (rng; observations, csbm, init_std)
186+
187+ û_history = Matrix {R} (undef, N, max_iterations)
188+ v̂_history = Matrix {R} (undef, P, max_iterations)
189+ converged = false
190+ prog = Progress (max_iterations; desc= " AMP-BP" , enabled= show_progress)
191+
192+ for t in 1 : max_iterations
193+ update_amp! (next_marginals, storage; marginals, observations, csbm)
194+ copy! (marginals, next_marginals)
195+
196+ û_history[:, t] .= marginals. û
197+ v̂_history[:, t] .= marginals. v̂
198+
199+ if t <= recent_past
200+ û_recent_std = typemax (R)
201+ v̂_recent_std = typemax (R)
202+ else
203+ û_recent_std = maximum (std (view (û_history, :, (t - recent_past): t); dims= 2 ))
204+ v̂_recent_std = maximum (std (view (v̂_history, :, (t - recent_past): t); dims= 2 ))
205+ end
206+ converged = (
207+ û_recent_std < convergence_threshold && v̂_recent_std < convergence_threshold
208+ )
209+ if converged
210+ û_history = û_history[:, 1 : t]
211+ v̂_history = v̂_history[:, 1 : t]
212+ break
213+ else
214+ showvalues = [
215+ (:û_recent_std , û_recent_std),
216+ (:v̂_recent_std , v̂_recent_std),
217+ (:convergence_threshold , convergence_threshold),
218+ ]
219+ next! (prog; showvalues)
220+ end
173221 end
174- return storage_history
175- end
176222
177- function overlap (; storage:: AMPStorage , latents:: ContextualSBMLatents )
178- û = sign .(storage. û)
179- @assert all (abs .(û) .> 0.5 )
180- u = latents. u
181- q̂ᵤ = (1 / length (û)) * max (count_equalities (û, u), count_equalities (û, - u))
182- qᵤ = 2 * (q̂ᵤ - 0.5 )
183- return qᵤ
223+ return (; û_history, v̂_history, converged)
184224end
0 commit comments