11# # Marginals
22
33"""
4- AMPMarginalsCSBM
4+ MarginalsCSBM
55
66# Fields
77
1111- `v̂_no_comm::Vector`: posterior mean of `v` if there were no communities, length `P` (aka `Bᵤ`)
1212- `h̃₊::Vector`: individual external field for `u=1`, length `N`
1313- `h̃₋::Vector`: individual external field for `u=-1`, length `N`
14- - `χ₊e ::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
14+ - `χe₊ ::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
1515- `χ₊::Vector`: marginal probability of `u=1`, length `N`
1616"""
17- @kwdef struct AMPMarginalsCSBM {R<: Real }
17+ @kwdef struct MarginalsCSBM {R<: Real }
1818 û:: Vector{R}
1919 v̂:: Vector{R}
2020 û_no_feat:: Vector{R}
2121 v̂_no_comm:: Vector{R}
2222 h̃₊:: Vector{R}
2323 h̃₋:: Vector{R}
24- χ₊e :: Dict{Tuple{Int,Int},R}
24+ χe₊ :: Dict{Tuple{Int,Int},R}
2525 χ₊:: Vector{R}
2626end
2727
28- function Base. copy (marginals:: AMPMarginalsCSBM )
29- return AMPMarginalsCSBM (;
28+ function Base. copy (marginals:: MarginalsCSBM )
29+ return MarginalsCSBM (;
3030 û= copy (marginals. û),
3131 v̂= copy (marginals. v̂),
3232 û_no_feat= copy (marginals. û_no_feat),
3333 v̂_no_comm= copy (marginals. v̂_no_comm),
3434 h̃₊= copy (marginals. h̃₊),
3535 h̃₋= copy (marginals. h̃₋),
36- χ₊e = copy (marginals. χ₊e ),
36+ χe₊ = copy (marginals. χe₊ ),
3737 χ₊= copy (marginals. χ₊),
3838 )
3939end
4040
41- function Base. copy! (marginals_dest:: AMPMarginalsCSBM , marginals_source:: AMPMarginalsCSBM )
41+ function Base. copy! (marginals_dest:: MarginalsCSBM , marginals_source:: MarginalsCSBM )
4242 copy! (marginals_dest. û, marginals_source. û),
4343 copy! (marginals_dest. v̂, marginals_source. v̂),
4444 copy! (marginals_dest. û_no_feat, marginals_source. û_no_feat),
4545 copy! (marginals_dest. v̂_no_comm, marginals_source. v̂_no_comm),
4646 copy! (marginals_dest. h̃₊, marginals_source. h̃₊),
4747 copy! (marginals_dest. h̃₋, marginals_source. h̃₋),
48- copy! (marginals_dest. χ₊e , marginals_source. χ₊e ),
48+ copy! (marginals_dest. χe₊ , marginals_source. χe₊ ),
4949 copy! (marginals_dest. χ₊, marginals_source. χ₊),
5050 return marginals_dest
5151end
@@ -60,42 +60,42 @@ function init_amp(
6060
6161 û = 2 .* prior₊ .(R, Ξ) .- one (R) .+ init_std .* randn (rng, R, N)
6262 v̂ = init_std .* randn (rng, R, P)
63-
63+
6464 û_no_feat = zeros (R, N)
6565 v̂_no_comm = zeros (R, P)
66-
66+
6767 h̃₊ = zeros (R, N)
6868 h̃₋ = zeros (R, N)
69-
70- χ₊e = Dict {Tuple{Int,Int},R} ()
69+
70+ χe₊ = Dict {Tuple{Int,Int},R} ()
7171 for i in 1 : N, j in neighbors (g, i)
72- χ₊e [i, j] = prior₊ (R, Ξ[i]) + init_std * randn (rng, R)
72+ χe₊ [i, j] = prior₊ (R, Ξ[i]) + init_std * randn (rng, R)
7373 end
7474 χ₊ = zeros (R, N)
75-
76- marginals = AMPMarginalsCSBM (; û, v̂, û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊e , χ₊)
75+
76+ marginals = MarginalsCSBM (; û, v̂, û_no_feat, v̂_no_comm, h̃₊, h̃₋, χe₊ , χ₊)
7777 next_marginals = copy (marginals)
7878 return (; marginals, next_marginals)
7979end
8080
8181function update_amp! (
82- next_marginals:: AMPMarginalsCSBM {R} ;
83- marginals:: AMPMarginalsCSBM {R} ,
82+ next_marginals:: MarginalsCSBM {R} ;
83+ marginals:: MarginalsCSBM {R} ,
8484 observations:: ObservationsCSBM{R} ,
8585 csbm:: CSBM{R} ,
8686) where {R}
8787 (; d, λ, μ, N, P) = csbm
88- (; g, B, Ξ ) = observations
88+ (; g, Ξ, B ) = observations
8989 (; cᵢ, cₒ) = affinities (csbm)
9090
91- ûᵗ, v̂ᵗ, χ₊eᵗ = marginals. û, marginals. v̂, marginals. χ₊e
92- ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ ⁺¹ = next_marginals. û, next_marginals. v̂, next_marginals. χ₊e
91+ ûᵗ, v̂ᵗ, χe₊ᵗ = marginals. û, marginals. v̂, marginals. χe₊
92+ ûᵗ⁺¹, v̂ᵗ⁺¹, χe₊ᵗ ⁺¹ = next_marginals. û, next_marginals. v̂, next_marginals. χe₊
9393 (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = next_marginals
9494
9595 ûₜ_sum = sum (ûᵗ)
9696 ûₜ_sum2 = sum (abs2, ûᵗ)
9797
98- # CSBMAMP estimation of v
98+ # AMP estimation of v
9999 σᵥ_no_comm = (μ / N) * ûₜ_sum2
100100 mul! (v̂_no_comm, B, ûᵗ)
101101 v̂_no_comm .*= sqrt (μ / N)
@@ -118,23 +118,23 @@ function update_amp!(
118118 for i in 1 : N
119119 s_i = h̃₊[i] - h̃₋[i]
120120 for k in neighbors (g, i)
121- common = 2 λ * sqrt (d) * χ₊eᵗ [k, i]
121+ common = 2 λ * sqrt (d) * χe₊ᵗ [k, i]
122122 s_i += log ((cₒ + common) / (cᵢ - common))
123123 end
124124 χ₊[i] = s_i
125125 end
126126
127127 # BP update of the messages
128128 for i in 1 : N, j in neighbors (g, i)
129- common = 2 λ * sqrt (d) * χ₊eᵗ [j, i]
129+ common = 2 λ * sqrt (d) * χe₊ᵗ [j, i]
130130 s_ij = log ((cₒ + common) / (cᵢ - common))
131- χ₊eᵗ ⁺¹[i, j] = χ₊[i] - s_ij
131+ χe₊ᵗ ⁺¹[i, j] = χ₊[i] - s_ij
132132 end
133133
134134 # Sigmoidize probabilities
135135 χ₊ .= sigmoid .(χ₊)
136- for (key, val) in pairs (χ₊eᵗ ⁺¹)
137- χ₊eᵗ ⁺¹[key] = sigmoid (val)
136+ for (key, val) in pairs (χe₊ᵗ ⁺¹)
137+ χe₊ᵗ ⁺¹[key] = sigmoid (val)
138138 end
139139
140140 # BP estimation of u
0 commit comments