55
66# Fields
77
8- - `a::Vector{R} `: length `M`
9- - `v::Vector{R} `: length `M`
10- - `Γ::Vector{R} `: length `M`
11- - `ω::Vector{R} `: length `N`
12- - `gₒ::Vector{R} `: length `N`
13- - `ψl₊ ::Vector{R }`: length `N`
14- - `χl₊::Vector{R} `: length `N `
15- - `χe₊::Dict{Tuple{Int,Int},R} `: size `(N, N) `
16- - `χ₊ ::Vector{R }`: length `N`
8+ - `a::Vector`: length `M`
9+ - `v::Vector`: length `M`
10+ - `Γ::Vector`: length `M`
11+ - `ω::Vector`: length `N`
12+ - `gₒ::Vector`: length `N`
13+ - `ψl::Vector{PlusMinusMeasure }`: length `N`
14+ - `χe::Dict{Tuple{Int,Int},PlusMinusMeasure} `: size `(N, N) `
15+ - `χl::Vector{PlusMinusMeasure} `: length `N `
16+ - `χ::Vector{PlusMinusMeasure }`: length `N`
1717
1818"""
1919@kwdef struct MarginalsGLMSBM{R<: Real }
2222 Γ:: Vector{R}
2323 ω:: Vector{R}
2424 gₒ:: Vector{R}
25- ψl₊:: Vector{R}
26- ψl₋:: Vector{R}
27- χl₊:: Vector{R}
28- χl₋:: Vector{R}
29- χe₊:: Dict{Tuple{Int,Int},R}
30- χe₋:: Dict{Tuple{Int,Int},R}
31- χ₊:: Vector{R}
32- χ₋:: Vector{R}
25+ ψl:: Vector{PlusMinusMeasure{R}}
26+ χe:: Dict{Tuple{Int,Int},PlusMinusMeasure{R}}
27+ χl:: Vector{PlusMinusMeasure{R}}
28+ χ:: Vector{PlusMinusMeasure{R}}
3329end
3430
3531function Base. copy (marginals:: MarginalsGLMSBM )
36- return MarginalsCSBM (;
32+ return MarginalsGLMSBM (;
3733 a= copy (marginals. a),
3834 v= copy (marginals. v),
3935 Γ= copy (marginals. Γ),
4036 ω= copy (marginals. ω),
4137 gₒ= copy (marginals. gₒ),
42- ψl₊= copy (marginals. ψl₊),
43- ψl₋= copy (marginals. ψl₋),
44- χl₊= copy (marginals. χl₊),
45- χl₋= copy (marginals. χl₋),
46- χe₊= copy (marginals. χe₊),
47- χe₋= copy (marginals. χe₋),
48- χ₊= copy (marginals. χ₊),
49- χ₋= copy (marginals. χ₋),
38+ ψl= copy (marginals. ψl),
39+ χe= copy (marginals. χe),
40+ χl= copy (marginals. χl),
41+ χ= copy (marginals. χ),
5042 )
5143end
5244
@@ -56,14 +48,10 @@ function Base.copy!(marginals_dest::MarginalsGLMSBM, marginals_source::Marginals
5648 copy! (marginals_dest. Γ, marginals_source. Γ)
5749 copy! (marginals_dest. ω, marginals_source. ω)
5850 copy! (marginals_dest. gₒ, marginals_source. gₒ)
59- copy! (marginals_dest. ψl₊, marginals_source. ψl₊)
60- copy! (marginals_dest. ψl₋, marginals_source. ψl₋)
61- copy! (marginals_dest. χl₊, marginals_source. χl₊)
62- copy! (marginals_dest. χl₋, marginals_source. χl₋)
63- copy! (marginals_dest. χe₊, marginals_source. χe₊)
64- copy! (marginals_dest. χe₋, marginals_source. χe₋)
65- copy! (marginals_dest. χ₊, marginals_source. χ₊)
66- copy! (marginals_dest. χ₋, marginals_source. χ₋)
51+ copy! (marginals_dest. ψl, marginals_source. ψl)
52+ copy! (marginals_dest. χe, marginals_source. χe)
53+ copy! (marginals_dest. χl, marginals_source. χl)
54+ copy! (marginals_dest. χ, marginals_source. χ)
6755 return marginals_dest
6856end
6957
@@ -82,14 +70,16 @@ function init_amp(
8270 ω = zeros (R, N)
8371 gₒ = zeros (R, N)
8472
85- χl₊ = ones (R, N ) / 2
86- χe₊ = Dict {Tuple{Int,Int},R } ()
73+ ψl = [ PlusMinusMeasure ( one (R ) / 2 ) for μ in 1 : N]
74+ χe = Dict {Tuple{Int,Int},PlusMinusMeasure{R} } ()
8775 for μ in 1 : N, ν in neighbors (g, μ)
88- χe₊[μ, ν] = one (R) / 2 + init_std * randn (rng, R)
76+ p = one (R) / 2 + init_std * randn (rng, R)
77+ χe[μ, ν] = PlusMinusMeasure (clamp (p, zero (R), one (R)))
8978 end
90- χ₊ = zeros (R, N)
79+ χl = [PlusMinusMeasure (one (R) / 2 ) for μ in 1 : N]
80+ χ = [PlusMinusMeasure (one (R) / 2 ) for μ in 1 : N]
9181
92- marginals = MarginalsGLMSBM (; a, v, Γ, ω, gₒ, ψl₊, ψl₋, χl₊, χl₋, χe₊, χe₋, χ₊, χ₋ )
82+ marginals = MarginalsGLMSBM (; a, v, Γ, ω, gₒ, ψl, χl, χe, χ )
9383 next_marginals = copy (marginals)
9484 return (; marginals, next_marginals)
9585end
@@ -103,36 +93,30 @@ function update_amp!(
10393 (; N, M, c, λ, Pʷ) = glmsbm
10494 (; g, Ξ, F) = observations
10595 (; cᵢ, cₒ) = affinities (glmsbm)
96+ C = AffinityMatrix (cᵢ, cₒ)
10697
107- (aᵗ, vᵗ, Γᵗ, ωᵗ, gₒᵗ, ψl₊ᵗ, ψl₋ᵗ, χl₊ᵗ, χl₋ᵗ, χe₊ᵗ, χe₋ᵗ, χ₊ᵗ, χ₋ᵗ ) = (
98+ (aᵗ, vᵗ, Γᵗ, ωᵗ, gₒᵗ, ψlᵗ, χeᵗ, χlᵗ, χᵗ ) = (
10899 marginals. a,
109100 marginals. v,
110101 marginals. Γ,
111102 marginals. ω,
112103 marginals. gₒ,
113- marginals. ψl₊,
114- marginals. ψl₋,
115- marginals. χl₊,
116- marginals. χl₋,
117- marginals. χe₊,
118- marginals. χe₋,
119- marginals. χ₊,
120- marginals. χ₋,
104+ marginals. ψl,
105+ marginals. χe,
106+ marginals. χl,
107+ marginals. χ,
121108 )
122- (aᵗ⁺¹, vᵗ⁺¹, Γᵗ⁺¹, ωᵗ⁺¹, gₒᵗ⁺¹, ψl₊ᵗ⁺¹, ψl₋ᵗ⁺¹, χl₊ᵗ⁺¹, χl₋ᵗ⁺¹, χe₊ᵗ⁺¹, χe₋ᵗ⁺¹, χ₊ᵗ⁺¹, χ₋ᵗ⁺¹) = (
109+
110+ (aᵗ⁺¹, vᵗ⁺¹, Γᵗ⁺¹, ωᵗ⁺¹, gₒᵗ⁺¹, ψlᵗ⁺¹, χeᵗ⁺¹, χlᵗ⁺¹, χᵗ⁺¹) = (
123111 next_marginals. a,
124112 next_marginals. v,
125113 next_marginals. Γ,
126114 next_marginals. ω,
127115 next_marginals. gₒ,
128- next_marginals. ψl₊,
129- next_marginals. ψl₋,
130- next_marginals. χl₊,
131- next_marginals. χl₋,
132- next_marginals. χe₊,
133- next_marginals. χe₋,
134- next_marginals. χ₊,
135- next_marginals. χ₋,
116+ next_marginals. ψl,
117+ next_marginals. χe,
118+ next_marginals. χl,
119+ next_marginals. χ,
136120 )
137121
138122 # AMP update of ω, V
@@ -141,29 +125,77 @@ function update_amp!(
141125 ωᵗ⁺¹ .- = Vᵗ⁺¹ .* gₒᵗ
142126
143127 # AMP update of ψl, gₒ, μ, Λ, Γ
144- ψl₊ .= missing # TODO : fix
145- gₒᵗ⁺¹ = gₒ .(ωᵗ⁺¹, χl₊ᵗ, χl₋ᵗ, Ref (Vᵗ⁺¹))
128+ for μ in 1 : N, sμ in (- 1 , 1 )
129+ ψlᵗ⁺¹[μ][sμ] = (one (R) + sμ * erf (ωᵗ⁺¹[μ] / sqrt (2 Vᵗ⁺¹))) / 2
130+ end
131+ gₒᵗ⁺¹ = gₒ .(ωᵗ⁺¹, χlᵗ, Ref (Vᵗ⁺¹))
146132 Λᵗ⁺¹ = sum (abs2, gₒᵗ⁺¹) / M
147133 mul! (Γᵗ⁺¹, F, gₒᵗ⁺¹)
148134 Γᵗ⁺¹ .+ = Λᵗ⁺¹ .* aᵗ
149135
150136 # AMP update of the estimated marginals a, v
151- aᵗ⁺¹ .= fₐ .(Ref (Λᵗ⁺¹), Γᵗ⁺¹)
152- vᵗ⁺¹ .= fᵥ .(Ref (Λᵗ⁺¹), Γᵗ⁺¹)
137+ aᵗ⁺¹ .= fₐ .(Ref (Pʷ), Ref ( Λᵗ⁺¹), Γᵗ⁺¹)
138+ vᵗ⁺¹ .= fᵥ .(Ref (Pʷ), Ref ( Λᵗ⁺¹), Γᵗ⁺¹)
153139
154140 # BP update of the field h
155- h₊ᵗ⁺¹
156- h₋ᵗ⁺¹
141+ hᵗ⁺¹ = PlusMinusMeasure (zero (R), zero (R))
142+ for s in (- 1 , 1 )
143+ hᵗ⁺¹[s] = sum (C[s, sμ] * χᵗ[μ][sμ] for μ in 1 : N for sμ in (- 1 , 1 ))
144+ end
157145
158146 # BP update of the messages χe and of the marginals χ
159- χe₊ᵗ⁺¹
160- χe₋ᵗ⁺¹
161- χ₊ᵗ⁺¹
162- χ₋ᵗ⁺¹
147+ for μ in 1 : N, ν in neighbors (g, μ)
148+ for sμ in (- 1 , 1 )
149+ χeᵗ⁺¹[μ, ν][sμ] = prior (R, sμ, Ξ[μ]) * exp (- hᵗ⁺¹[sμ]) * ψlᵗ⁺¹[μ][sμ]
150+ for η in neighbors (g, μ)
151+ if η != ν
152+ χeᵗ⁺¹[μ, ν][sμ] *= sum (C[sη, sμ] * χeᵗ[η, μ][sη] for sη in (- 1 , 1 ))
153+ end
154+ end
155+ end
156+ normalize! (χeᵗ⁺¹[μ, ν])
157+ end
158+ for μ in 1 : N
159+ for sμ in (- 1 , 1 )
160+ χᵗ⁺¹[μ][sμ] = prior (R, sμ, Ξ[μ]) * exp (- hᵗ⁺¹[sμ]) * ψlᵗ⁺¹[μ][sμ]
161+ for η in neighbors (g, μ)
162+ χᵗ⁺¹[μ][sμ] *= sum (C[sη, sμ] * χeᵗ[η, μ][sη] for sη in (- 1 , 1 ))
163+ end
164+ end
165+ normalize! (χᵗ⁺¹[μ])
166+ end
163167
164168 # BP update of the SBM-to-GLM messages χl
165- χl₊ᵗ⁺¹
166- χl₋ᵗ⁺¹
169+ for μ in 1 : N
170+ for sμ in (- 1 , 1 )
171+ χlᵗ⁺¹[μ][sμ] = prior (R, sμ, Ξ[μ]) * exp (- hᵗ⁺¹[sμ])
172+ for η in neighbors (g, μ)
173+ χlᵗ⁺¹[μ][sμ] *= sum (C[sη, sμ] * χeᵗ[η, μ][sη] for sη in (- 1 , 1 ))
174+ end
175+ end
176+ normalize! (χlᵗ⁺¹[μ])
177+ end
167178
168179 return nothing
169180end
181+
182+ function run_amp (
183+ rng:: AbstractRNG ,
184+ observations:: ObservationsGLMSBM{R} ,
185+ glmsbm:: GLMSBM{R} ;
186+ init_std= 1e-3 ,
187+ max_iterations= 200 ,
188+ convergence_threshold= 1e-3 ,
189+ recent_past= 10 ,
190+ show_progress= false ,
191+ ) where {R}
192+ (; marginals, next_marginals) = init_amp (rng; observations, glmsbm, init_std)
193+ converged = false
194+ prog = Progress (max_iterations; desc= " AMP-BP for GLM-SBM" , enabled= show_progress)
195+ for t in 1 : max_iterations
196+ update_amp! (next_marginals; marginals, observations, glmsbm)
197+ copy! (marginals, next_marginals)
198+ next! (prog)
199+ end
200+ return marginals
201+ end
0 commit comments