Skip to content

Commit 5b2fc68

Browse files
committed
GLMSBM runs but fails
1 parent 556a63d commit 5b2fc68

File tree

9 files changed

+185
-92
lines changed

9 files changed

+185
-92
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0-beta2"
44
manifest_format = "2.0"
5-
project_hash = "3ba5719890857eaf711cde76af0aade7729abd07"
5+
project_hash = "5d6b46b606378213b4547f42572ed5ae33cb9d37"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

src/StochasticBlockModelVariants.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module StochasticBlockModelVariants
22

33
using DensityInterface: logdensityof, densityof
44
using Graphs: AbstractGraph, neighbors
5-
using LinearAlgebra: dot, mul!, norm
5+
using LinearAlgebra: LinearAlgebra, dot, mul!, norm, normalize!
66
using PrecompileTools: @compile_workload
77
using ProgressMeter: Progress, next!
88
using Random: AbstractRNG, default_rng
@@ -22,6 +22,7 @@ include("utils.jl")
2222
include("csbm.jl")
2323
include("glmsbm.jl")
2424
include("csbm_inference.jl")
25+
include("glmsbm_inference.jl")
2526

2627
# @compile_workload begin
2728
# rng = default_rng()

src/abstract_sbm.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ function affinities(sbm::AbstractSBM)
4646
return (; cᵢ, cₒ)
4747
end
4848

49+
struct AffinityMatrix{R}
50+
cᵢ::R
51+
cₒ::R
52+
end
53+
54+
function Base.getindex(C::AffinityMatrix, i, j)
55+
if i == j
56+
return C.cᵢ
57+
else
58+
return C.cₒ
59+
end
60+
end
61+
4962
"""
5063
fraction_observed(sbm)
5164
@@ -88,7 +101,7 @@ Sample a vector `Ξ` (Xi) whose components are equal to the community assignment
88101
function sample_mask(rng::AbstractRNG, sbm::AbstractSBM, communities::Vector{<:Integer})
89102
N = length(sbm)
90103
ρ = fraction_observed(sbm)
91-
Ξ = Vector{Union{Missing, Int}}(undef, N)
104+
Ξ = Vector{Union{Missing,Int}}(undef, N)
92105
Ξ .= missing
93106
for i in 1:N
94107
if rand(rng) < ρ
@@ -98,5 +111,4 @@ function sample_mask(rng::AbstractRNG, sbm::AbstractSBM, communities::Vector{<:I
98111
return Ξ
99112
end
100113

101-
prior₊(::Type{R}, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == 1)
102-
prior₋(::Type{R}, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == -1)
114+
prior(::Type{R}, u, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == u)

src/csbm_inference.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ end
5353
## Message-passing
5454

5555
function init_amp(
56-
rng::AbstractRNG; observations::ObservationsCSBM{R}, csbm::CSBM{R}, init_std
56+
rng::AbstractRNG, observations::ObservationsCSBM{R}, csbm::CSBM{R}; init_std
5757
) where {R}
5858
(; N, P) = csbm
5959
(; g, Ξ) = observations
6060

61-
= 2 .* prior.(R, Ξ) .- one(R) .+ init_std .* randn(rng, R, N)
61+
= 2 .* prior.(R, 1, Ξ) .- one(R) .+ init_std .* randn(rng, R, N)
6262
= init_std .* randn(rng, R, P)
6363

6464
û_no_feat = zeros(R, N)
@@ -69,7 +69,7 @@ function init_amp(
6969

7070
χe₊ = Dict{Tuple{Int,Int},R}()
7171
for i in 1:N, j in neighbors(g, i)
72-
χe₊[i, j] = prior(R, Ξ[i]) + init_std * randn(rng, R)
72+
χe₊[i, j] = prior(R, 1, Ξ[i]) + init_std * randn(rng, R)
7373
end
7474
χ₊ = zeros(R, N)
7575

@@ -79,7 +79,7 @@ function init_amp(
7979
end
8080

8181
function update_amp!(
82-
next_marginals::MarginalsCSBM{R};
82+
next_marginals::MarginalsCSBM{R},
8383
marginals::MarginalsCSBM{R},
8484
observations::ObservationsCSBM{R},
8585
csbm::CSBM{R},
@@ -111,8 +111,8 @@ function update_amp!(
111111
# Estimation of the field h
112112
h₊ = (one(R) / 2N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
113113
h₋ = (one(R) / 2N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
114-
h̃₊ .= -h₊ .+ log.(prior.(R, Ξ)) .+ û_no_feat
115-
h̃₋ .= -h₋ .+ log.(prior.(R, Ξ)) .- û_no_feat
114+
h̃₊ .= -h₊ .+ log.(prior.(R, +1, Ξ)) .+ û_no_feat
115+
h̃₋ .= -h₋ .+ log.(prior.(R, -1, Ξ)) .- û_no_feat
116116

117117
# BP update of the marginals
118118
for i in 1:N
@@ -144,25 +144,25 @@ function update_amp!(
144144
end
145145

146146
function run_amp(
147-
rng::AbstractRNG;
147+
rng::AbstractRNG,
148148
observations::ObservationsCSBM{R},
149-
csbm::CSBM{R},
149+
csbm::CSBM{R};
150150
init_std=1e-3,
151151
max_iterations=200,
152152
convergence_threshold=1e-3,
153153
recent_past=10,
154154
show_progress=false,
155155
) where {R}
156156
(; N, P) = csbm
157-
(; marginals, next_marginals) = init_amp(rng; observations, csbm, init_std)
157+
(; marginals, next_marginals) = init_amp(rng, observations, csbm; init_std)
158158

159159
û_history = Matrix{R}(undef, N, max_iterations)
160160
v̂_history = Matrix{R}(undef, P, max_iterations)
161161
converged = false
162162
prog = Progress(max_iterations; desc="AMP-BP for CSBM", enabled=show_progress)
163163

164164
for t in 1:max_iterations
165-
update_amp!(next_marginals; marginals, observations, csbm)
165+
update_amp!(next_marginals, marginals, observations, csbm)
166166
copy!(marginals, next_marginals)
167167

168168
û_history[:, t] .= marginals.
@@ -197,7 +197,7 @@ end
197197

198198
function evaluate_amp(rng::AbstractRNG; csbm::CSBM, kwargs...)
199199
(; latents, observations) = rand(rng, csbm)
200-
(; û_history, v̂_history, converged) = run_amp(rng; observations, csbm, kwargs...)
200+
(; û_history, v̂_history, converged) = run_amp(rng, observations, csbm; kwargs...)
201201
(; qᵤ, qᵥ) = overlaps(;
202202
u=latents.u, v=latents.v, û=û_history[:, end], v̂=v̂_history[:, end]
203203
)

src/glmsbm.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,8 @@ fᵥ(::RademacherWeightPrior, Λ, Γ) = inv(abs2(cosh(Γ)))
110110

111111
fₐ(::GaussianWeightPrior, Λ, Γ) = Γ /+ 1)
112112
fᵥ(::GaussianWeightPrior, Λ, Γ) = 1 /+ 1)
113+
114+
function gₒ(ω, χ, V)
115+
Znn = (1 + (2χ[1] - 1) * erf/ sqrt(2V))) / 2
116+
return inv(sqrt(2π * V)) * (2χ[1] - 1) * exp(-ω^2 / (2V)) / Znn
117+
end

src/glmsbm_inference.jl

Lines changed: 101 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
66
# Fields
77
8-
- `a::Vector{R}`: length `M`
9-
- `v::Vector{R}`: length `M`
10-
- `Γ::Vector{R}`: length `M`
11-
- `ω::Vector{R}`: length `N`
12-
- `gₒ::Vector{R}`: length `N`
13-
- `ψl::Vector{R}`: length `N`
14-
- `χl₊::Vector{R}`: length `N`
15-
- `χe₊::Dict{Tuple{Int,Int},R}`: size `(N, N)`
16-
- `χ::Vector{R}`: length `N`
8+
- `a::Vector`: length `M`
9+
- `v::Vector`: length `M`
10+
- `Γ::Vector`: length `M`
11+
- `ω::Vector`: length `N`
12+
- `gₒ::Vector`: length `N`
13+
- `ψl::Vector{PlusMinusMeasure}`: length `N`
14+
- `χe::Dict{Tuple{Int,Int},PlusMinusMeasure}`: size `(N, N)`
15+
- `χl::Vector{PlusMinusMeasure}`: length `N`
16+
- `χ::Vector{PlusMinusMeasure}`: length `N`
1717
1818
"""
1919
@kwdef struct MarginalsGLMSBM{R<:Real}
@@ -22,31 +22,23 @@
2222
Γ::Vector{R}
2323
ω::Vector{R}
2424
gₒ::Vector{R}
25-
ψl₊::Vector{R}
26-
ψl₋::Vector{R}
27-
χl₊::Vector{R}
28-
χl₋::Vector{R}
29-
χe₊::Dict{Tuple{Int,Int},R}
30-
χe₋::Dict{Tuple{Int,Int},R}
31-
χ₊::Vector{R}
32-
χ₋::Vector{R}
25+
ψl::Vector{PlusMinusMeasure{R}}
26+
χe::Dict{Tuple{Int,Int},PlusMinusMeasure{R}}
27+
χl::Vector{PlusMinusMeasure{R}}
28+
χ::Vector{PlusMinusMeasure{R}}
3329
end
3430

3531
function Base.copy(marginals::MarginalsGLMSBM)
36-
return MarginalsCSBM(;
32+
return MarginalsGLMSBM(;
3733
a=copy(marginals.a),
3834
v=copy(marginals.v),
3935
Γ=copy(marginals.Γ),
4036
ω=copy(marginals.ω),
4137
gₒ=copy(marginals.gₒ),
42-
ψl₊=copy(marginals.ψl₊),
43-
ψl₋=copy(marginals.ψl₋),
44-
χl₊=copy(marginals.χl₊),
45-
χl₋=copy(marginals.χl₋),
46-
χe₊=copy(marginals.χe₊),
47-
χe₋=copy(marginals.χe₋),
48-
χ₊=copy(marginals.χ₊),
49-
χ₋=copy(marginals.χ₋),
38+
ψl=copy(marginals.ψl),
39+
χe=copy(marginals.χe),
40+
χl=copy(marginals.χl),
41+
χ=copy(marginals.χ),
5042
)
5143
end
5244

@@ -56,14 +48,10 @@ function Base.copy!(marginals_dest::MarginalsGLMSBM, marginals_source::Marginals
5648
copy!(marginals_dest.Γ, marginals_source.Γ)
5749
copy!(marginals_dest.ω, marginals_source.ω)
5850
copy!(marginals_dest.gₒ, marginals_source.gₒ)
59-
copy!(marginals_dest.ψl₊, marginals_source.ψl₊)
60-
copy!(marginals_dest.ψl₋, marginals_source.ψl₋)
61-
copy!(marginals_dest.χl₊, marginals_source.χl₊)
62-
copy!(marginals_dest.χl₋, marginals_source.χl₋)
63-
copy!(marginals_dest.χe₊, marginals_source.χe₊)
64-
copy!(marginals_dest.χe₋, marginals_source.χe₋)
65-
copy!(marginals_dest.χ₊, marginals_source.χ₊)
66-
copy!(marginals_dest.χ₋, marginals_source.χ₋)
51+
copy!(marginals_dest.ψl, marginals_source.ψl)
52+
copy!(marginals_dest.χe, marginals_source.χe)
53+
copy!(marginals_dest.χl, marginals_source.χl)
54+
copy!(marginals_dest.χ, marginals_source.χ)
6755
return marginals_dest
6856
end
6957

@@ -82,14 +70,16 @@ function init_amp(
8270
ω = zeros(R, N)
8371
gₒ = zeros(R, N)
8472

85-
χl₊ = ones(R, N) / 2
86-
χe = Dict{Tuple{Int,Int},R}()
73+
ψl = [PlusMinusMeasure(one(R) / 2) for μ in 1:N]
74+
χe = Dict{Tuple{Int,Int},PlusMinusMeasure{R}}()
8775
for μ in 1:N, ν in neighbors(g, μ)
88-
χe₊[μ, ν] = one(R) / 2 + init_std * randn(rng, R)
76+
p = one(R) / 2 + init_std * randn(rng, R)
77+
χe[μ, ν] = PlusMinusMeasure(clamp(p, zero(R), one(R)))
8978
end
90-
χ₊ = zeros(R, N)
79+
χl = [PlusMinusMeasure(one(R) / 2) for μ in 1:N]
80+
χ = [PlusMinusMeasure(one(R) / 2) for μ in 1:N]
9181

92-
marginals = MarginalsGLMSBM(; a, v, Γ, ω, gₒ, ψl₊, ψl₋, χl₊, χl₋, χe₊, χe₋, χ₊, χ₋)
82+
marginals = MarginalsGLMSBM(; a, v, Γ, ω, gₒ, ψl, χl, χe, χ)
9383
next_marginals = copy(marginals)
9484
return (; marginals, next_marginals)
9585
end
@@ -103,36 +93,30 @@ function update_amp!(
10393
(; N, M, c, λ, Pʷ) = glmsbm
10494
(; g, Ξ, F) = observations
10595
(; cᵢ, cₒ) = affinities(glmsbm)
96+
C = AffinityMatrix(cᵢ, cₒ)
10697

107-
(aᵗ, vᵗ, Γᵗ, ωᵗ, gₒᵗ, ψl₊ᵗ, ψl₋ᵗ, χl₊ᵗ, χl₋ᵗ, χe₊ᵗ, χe₋ᵗ, χ₊ᵗ, χ₋ᵗ) = (
98+
(aᵗ, vᵗ, Γᵗ, ωᵗ, gₒᵗ, ψlᵗ, χeᵗ, χlᵗ, χᵗ) = (
10899
marginals.a,
109100
marginals.v,
110101
marginals.Γ,
111102
marginals.ω,
112103
marginals.gₒ,
113-
marginals.ψl₊,
114-
marginals.ψl₋,
115-
marginals.χl₊,
116-
marginals.χl₋,
117-
marginals.χe₊,
118-
marginals.χe₋,
119-
marginals.χ₊,
120-
marginals.χ₋,
104+
marginals.ψl,
105+
marginals.χe,
106+
marginals.χl,
107+
marginals.χ,
121108
)
122-
(aᵗ⁺¹, vᵗ⁺¹, Γᵗ⁺¹, ωᵗ⁺¹, gₒᵗ⁺¹, ψl₊ᵗ⁺¹, ψl₋ᵗ⁺¹, χl₊ᵗ⁺¹, χl₋ᵗ⁺¹, χe₊ᵗ⁺¹, χe₋ᵗ⁺¹, χ₊ᵗ⁺¹, χ₋ᵗ⁺¹) = (
109+
110+
(aᵗ⁺¹, vᵗ⁺¹, Γᵗ⁺¹, ωᵗ⁺¹, gₒᵗ⁺¹, ψlᵗ⁺¹, χeᵗ⁺¹, χlᵗ⁺¹, χᵗ⁺¹) = (
123111
next_marginals.a,
124112
next_marginals.v,
125113
next_marginals.Γ,
126114
next_marginals.ω,
127115
next_marginals.gₒ,
128-
next_marginals.ψl₊,
129-
next_marginals.ψl₋,
130-
next_marginals.χl₊,
131-
next_marginals.χl₋,
132-
next_marginals.χe₊,
133-
next_marginals.χe₋,
134-
next_marginals.χ₊,
135-
next_marginals.χ₋,
116+
next_marginals.ψl,
117+
next_marginals.χe,
118+
next_marginals.χl,
119+
next_marginals.χ,
136120
)
137121

138122
# AMP update of ω, V
@@ -141,29 +125,77 @@ function update_amp!(
141125
ωᵗ⁺¹ .-= Vᵗ⁺¹ .* gₒᵗ
142126

143127
# AMP update of ψl, gₒ, μ, Λ, Γ
144-
ψl₊ .= missing # TODO: fix
145-
gₒᵗ⁺¹ = gₒ.(ωᵗ⁺¹, χl₊ᵗ, χl₋ᵗ, Ref(Vᵗ⁺¹))
128+
for μ in 1:N, sμ in (-1, 1)
129+
ψlᵗ⁺¹[μ][sμ] = (one(R) +* erf(ωᵗ⁺¹[μ] / sqrt(2Vᵗ⁺¹))) / 2
130+
end
131+
gₒᵗ⁺¹ = gₒ.(ωᵗ⁺¹, χlᵗ, Ref(Vᵗ⁺¹))
146132
Λᵗ⁺¹ = sum(abs2, gₒᵗ⁺¹) / M
147133
mul!(Γᵗ⁺¹, F, gₒᵗ⁺¹)
148134
Γᵗ⁺¹ .+= Λᵗ⁺¹ .* aᵗ
149135

150136
# AMP update of the estimated marginals a, v
151-
aᵗ⁺¹ .= fₐ.(Ref(Λᵗ⁺¹), Γᵗ⁺¹)
152-
vᵗ⁺¹ .= fᵥ.(Ref(Λᵗ⁺¹), Γᵗ⁺¹)
137+
aᵗ⁺¹ .= fₐ.(Ref(Pʷ), Ref(Λᵗ⁺¹), Γᵗ⁺¹)
138+
vᵗ⁺¹ .= fᵥ.(Ref(Pʷ), Ref(Λᵗ⁺¹), Γᵗ⁺¹)
153139

154140
# BP update of the field h
155-
h₊ᵗ⁺¹
156-
h₋ᵗ⁺¹
141+
hᵗ⁺¹ = PlusMinusMeasure(zero(R), zero(R))
142+
for s in (-1, 1)
143+
hᵗ⁺¹[s] = sum(C[s, sμ] * χᵗ[μ][sμ] for μ in 1:N forin (-1, 1))
144+
end
157145

158146
# BP update of the messages χe and of the marginals χ
159-
χe₊ᵗ⁺¹
160-
χe₋ᵗ⁺¹
161-
χ₊ᵗ⁺¹
162-
χ₋ᵗ⁺¹
147+
for μ in 1:N, ν in neighbors(g, μ)
148+
forin (-1, 1)
149+
χeᵗ⁺¹[μ, ν][sμ] = prior(R, sμ, Ξ[μ]) * exp(-hᵗ⁺¹[sμ]) * ψlᵗ⁺¹[μ][sμ]
150+
for η in neighbors(g, μ)
151+
if η != ν
152+
χeᵗ⁺¹[μ, ν][sμ] *= sum(C[sη, sμ] * χeᵗ[η, μ][sη] forin (-1, 1))
153+
end
154+
end
155+
end
156+
normalize!(χeᵗ⁺¹[μ, ν])
157+
end
158+
for μ in 1:N
159+
forin (-1, 1)
160+
χᵗ⁺¹[μ][sμ] = prior(R, sμ, Ξ[μ]) * exp(-hᵗ⁺¹[sμ]) * ψlᵗ⁺¹[μ][sμ]
161+
for η in neighbors(g, μ)
162+
χᵗ⁺¹[μ][sμ] *= sum(C[sη, sμ] * χeᵗ[η, μ][sη] forin (-1, 1))
163+
end
164+
end
165+
normalize!(χᵗ⁺¹[μ])
166+
end
163167

164168
# BP update of the SBM-to-GLM messages χl
165-
χl₊ᵗ⁺¹
166-
χl₋ᵗ⁺¹
169+
for μ in 1:N
170+
forin (-1, 1)
171+
χlᵗ⁺¹[μ][sμ] = prior(R, sμ, Ξ[μ]) * exp(-hᵗ⁺¹[sμ])
172+
for η in neighbors(g, μ)
173+
χlᵗ⁺¹[μ][sμ] *= sum(C[sη, sμ] * χeᵗ[η, μ][sη] forin (-1, 1))
174+
end
175+
end
176+
normalize!(χlᵗ⁺¹[μ])
177+
end
167178

168179
return nothing
169180
end
181+
182+
function run_amp(
183+
rng::AbstractRNG,
184+
observations::ObservationsGLMSBM{R},
185+
glmsbm::GLMSBM{R};
186+
init_std=1e-3,
187+
max_iterations=200,
188+
convergence_threshold=1e-3,
189+
recent_past=10,
190+
show_progress=false,
191+
) where {R}
192+
(; marginals, next_marginals) = init_amp(rng; observations, glmsbm, init_std)
193+
converged = false
194+
prog = Progress(max_iterations; desc="AMP-BP for GLM-SBM", enabled=show_progress)
195+
for t in 1:max_iterations
196+
update_amp!(next_marginals; marginals, observations, glmsbm)
197+
copy!(marginals, next_marginals)
198+
next!(prog)
199+
end
200+
return marginals
201+
end

0 commit comments

Comments
 (0)