Skip to content

Commit 102792e

Browse files
committed
BP is working for CSBM
1 parent db38721 commit 102792e

File tree

4 files changed

+46
-43
lines changed

4 files changed

+46
-43
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.9.2"
44
manifest_format = "2.0"
5-
project_hash = "426d5ac467261afc2896c60ab2735a7aa4815844"
5+
project_hash = "fa8ae38c891ad59da40addcb7399824f2795f7b2"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

src/contextual_sbm.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ function Base.rand(rng::AbstractRNG, csbm::ContextualSBM)
7979

8080
r = rand(rng, N, N)
8181
Is, Js = Int[], Int[]
82-
for i in 1:N, j in 1:i
83-
if ((u[i] == u[j]) && (r[i, j] < cᵢ / N)) || ((u[i] != u[j]) && (r[i, j] < cₒ / N))
82+
for i in 1:N, j in (i + 1):N
83+
if (
84+
((u[i] == u[j]) && (r[i, j] < cᵢ / N)) || # same
85+
((u[i] != u[j]) && (r[i, j] < cₒ / N)) # diff
86+
)
8487
push!(Is, i)
8588
push!(Js, j)
8689
end
@@ -155,6 +158,16 @@ end
155158
χ₊::Vector{R}
156159
end
157160

161+
function Base.copy(temp_storage::AMPTempStorage)
162+
return AMPTempStorage(;
163+
û_no_feat=copy(temp_storage.û_no_feat),
164+
v̂_no_graph=copy(temp_storage.v̂_no_graph),
165+
h̃₊=copy(temp_storage.h̃₊),
166+
h̃₋=copy(temp_storage.h̃₋),
167+
χ₊=copy(temp_storage.χ₊),
168+
)
169+
end
170+
158171
function init_amp(
159172
rng::AbstractRNG;
160173
observations::ContextualSBMObservations,
@@ -191,8 +204,8 @@ function update_amp!(
191204
temp_storage::AMPTempStorage;
192205
storage::AMPStorage,
193206
observations::ContextualSBMObservations,
194-
csbm::ContextualSBM,
195-
)
207+
csbm::ContextualSBM{R},
208+
) where {R}
196209
(; d, λ, μ, N, P) = csbm
197210
(; B, G) = observations
198211
(; cᵢ, cₒ) = affinities(csbm)
@@ -223,18 +236,19 @@ function update_amp!(
223236
h₊ = (1 / N) * sum(cᵢ * (1 + ûᵗ[i]) / 2 + cₒ * (1 - ûᵗ[i]) / 2 for i in 1:N)
224237
h₋ = (1 / N) * sum(cₒ * (1 + ûᵗ[i]) / 2 + cᵢ * (1 - ûᵗ[i]) / 2 for i in 1:N)
225238
for i in 1:N
226-
h̃₊[i] = -h₊ + û_no_feat[i]
227-
h̃₋[i] = -h₋ - û_no_feat[i]
239+
h̃₊[i] = -h₊ + log(one(R) / 2) + û_no_feat[i]
240+
h̃₋[i] = -h₋ + log(one(R) / 2) - û_no_feat[i]
228241
end
229242

230243
# BP update of the messages
231244
for i in 1:N, j in neighbors(G, i)
232245
s_ij = h̃₊[i] - h̃₋[i]
233246
for k in neighbors(G, i)
234-
k != j || continue
235-
num = (cₒ + 2λ * sqrt(d) * χ₊eᵗ[k, i])
236-
den = (cᵢ - 2λ * sqrt(d) * χ₊eᵗ[k, i])
237-
s_ij += log(num / den)
247+
if k != j
248+
num = (cₒ + 2λ * sqrt(d) * χ₊eᵗ[k, i])
249+
den = (cᵢ - 2λ * sqrt(d) * χ₊eᵗ[k, i])
250+
s_ij += log(num / den)
251+
end
238252
end
239253
χ₊eᵗ⁺¹[i, j] = sigmoid(s_ij)
240254
end
@@ -267,19 +281,21 @@ function run_amp(
267281
)
268282
(; storage, next_storage, temp_storage) = init_amp(rng; observations, csbm, init_std)
269283
storage_history = [copy(storage)]
284+
temp_storage_history = [copy(temp_storage)]
270285
@showprogress "AMP-BP" for iter in 1:iterations
271286
update_amp!(next_storage, temp_storage; storage, observations, csbm)
272287
copy!(storage, next_storage)
273288
push!(storage_history, copy(storage))
289+
push!(temp_storage_history, copy(temp_storage))
274290
end
275291
return storage_history
276292
end
277293

278294
function evaluate_amp(; storage::AMPStorage, latents::ContextualSBMLatents)
295+
= sign.(storage.û)
296+
@assert all(abs.(û) .> 0.5)
279297
u = latents.u
280-
N = length(u)
281-
= 2 .* Int.(storage..> 0) .- 1
282-
q̂ᵤ = (1 / N) * max(count_equalities(û, u), count_equalities(û, -u))
298+
q̂ᵤ = (1 / length(û)) * max(count_equalities(û, u), count_equalities(û, -u))
283299
qᵤ = 2 * (q̂ᵤ - 0.5)
284300
return qᵤ
285301
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
sigmoid(x) = inv(1 + exp(-x))
1+
sigmoid(x) = 1 / (1 + exp(-x))
22

33
count_equalities(x, y) = sum(x[i] y[i] for i in eachindex(x, y))

test/contextual_sbm.jl

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,21 @@
11
using Random: default_rng
2+
using Statistics
23
using StochasticBlockModelVariants
34
using Test
45

56
rng = default_rng()
67

7-
# Parameters
8-
9-
d = 5.0
10-
λ = 2.0
11-
μ = 2.0
12-
N = 3 * 10^3
13-
P = 3 * 10^2
14-
15-
init_std = 1e-3
16-
iterations = 10
17-
18-
# Sampling
19-
20-
csbm = ContextualSBM(; d, λ, μ, N, P)
21-
22-
(; latents, observations) = rand(rng, csbm);
23-
24-
# Inference
25-
26-
(; storage, next_storage, temp_storage) = init_amp(rng; observations, csbm, init_std);
27-
28-
storage_history = run_amp(rng; observations, csbm, init_std, iterations);
29-
30-
@test effective_snr(csbm) > 1
31-
32-
overlap_history = [evaluate_amp(; storage, latents) for storage in storage_history]
33-
34-
@test_broken last(overlap_history) > 0.5
8+
function test_recovery(; d, λ, μ, N, P, init_std=1e-3, iterations=20)
9+
csbm = ContextualSBM(; d, λ, μ, N, P)
10+
@assert effective_snr(csbm) > 1
11+
(; latents, observations) = rand(rng, csbm)
12+
storage_history = run_amp(rng; observations, csbm, init_std, iterations)
13+
overlap_history = [evaluate_amp(; storage, latents) for storage in storage_history]
14+
@test last(overlap_history) > 10 * first(overlap_history)
15+
@test last(overlap_history) > 0.5
16+
return nothing
17+
end
18+
19+
test_recovery(; d=5.0, λ=2.0, μ=2.0, N=10^3, P=10^3) # AMP-BP
20+
test_recovery(; d=5.0, λ=0.0, μ=2.0, N=10^3, P=10^3) # AMP
21+
test_recovery(; d=5.0, λ=2.0, μ=0.0, N=10^3, P=10^3) # BP

0 commit comments

Comments
 (0)