Skip to content

Commit ebde6d0

Browse files
committed
Partially revealed communities
1 parent 9e739d8 commit ebde6d0

File tree

6 files changed

+61
-36
lines changed

6 files changed

+61
-36
lines changed

Manifest.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.9.2"
44
manifest_format = "2.0"
5-
project_hash = "76c4ad29b7fdf61caa5b4fe006d178c04b45b443"
5+
project_hash = "eabebeb24d3db0211e95e73a9e2e965bf16ba4e1"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -158,6 +158,12 @@ version = "1.4.0"
158158
deps = ["Unicode"]
159159
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
160160

161+
[[deps.ProgressMeter]]
162+
deps = ["Distributed", "Printf"]
163+
git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9"
164+
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
165+
version = "1.7.2"
166+
161167
[[deps.REPL]]
162168
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
163169
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.1.0"
77
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
10+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
1213
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -15,7 +16,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1516
[compat]
1617
Graphs = "1.8"
1718
PrecompileTools = "1.1"
18-
ProgressLogging = "0.1"
19+
ProgressMeter = "1.7"
1920
SimpleWeightedGraphs = "1.4"
2021
julia = "1.9"
2122

src/StochasticBlockModelVariants.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@ module StochasticBlockModelVariants
33
using Graphs: AbstractGraph, neighbors
44
using LinearAlgebra: dot, mul!
55
using PrecompileTools: @compile_workload
6+
using ProgressMeter: Progress, next!
67
using Random: AbstractRNG, default_rng
78
using SimpleWeightedGraphs: SimpleWeightedGraph
89
using Statistics: mean
910
using SparseArrays: SparseMatrixCSC, sparse, findnz
1011

1112
export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations
1213
export affinities, effective_snr
13-
export init_amp, update_amp!, run_amp, evaluate_amp
14+
export init_amp, update_amp!, run_amp
15+
export overlap
1416

1517
include("utils.jl")
1618
include("csbm.jl")
1719
include("csbm_inference.jl")
1820

1921
@compile_workload begin
2022
rng = default_rng()
21-
csbm = ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)
23+
csbm = ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.0)
2224
(; observations) = rand(rng, csbm)
2325
run_amp(rng; observations, csbm, iterations=2)
2426
end

src/csbm.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,18 @@ end
7979
The observations generated by sampling from a [`ContextualSBM`](@ref).
8080
8181
# Fields
82-
83-
- `g::AbstractGraph`: undirected, unweighted graph generated from the adjacency matrix
82+
- `A::AbstractMatrix`: symmetric boolean adjacency matrix, size `(N, N)`
83+
- `g::AbstractGraph`: undirected unweighted graph generated from `A`
8484
- `B::Matrix`: feature matrix, size `(P, N)`
85+
- `Ξ::Vector`: revealed communities `±1` for a fraction of nodes and `0` for the rest, length `N`
8586
"""
86-
@kwdef struct ContextualSBMObservations{R<:Real,G<:AbstractGraph{Int}}
87-
B::Matrix{R}
87+
@kwdef struct ContextualSBMObservations{
88+
R<:Real,M<:AbstractMatrix{Bool},G<:AbstractGraph{Int}
89+
}
90+
A::M
8891
g::G
92+
B::Matrix{R}
93+
Ξ::Vector{Int}
8994
end
9095

9196
## Simulation
@@ -96,7 +101,7 @@ end
96101
Sample from a [`ContextualSBM`](@ref) and return a named tuple `(; latents, observations)`.
97102
"""
98103
function Base.rand(rng::AbstractRNG, csbm::ContextualSBM)
99-
(; μ, N, P) = csbm
104+
(; N, P, μ, ρ) = csbm
100105
(; cᵢ, cₒ) = affinities(csbm)
101106

102107
u = rand(rng, (-1, +1), N)
@@ -121,10 +126,15 @@ function Base.rand(rng::AbstractRNG, csbm::ContextualSBM)
121126
g = SimpleWeightedGraph(A)
122127

123128
B = randn(rng, P, N)
124-
for i in 1:N, α in 1:P
125-
B[α, i] += sqrt/ N) * v[α] * u[i]
129+
B .+= sqrt/ N) .* v .* u'
130+
131+
Ξ = zeros(Int, N)
132+
for i in 1:N
133+
if rand(rng) < ρ
134+
Ξ[i] = u[i]
135+
end
126136
end
127137

128-
observations = ContextualSBMObservations(; g, B)
138+
observations = ContextualSBMObservations(; A, g, B, Ξ)
129139
return (; latents, observations)
130140
end

src/csbm_inference.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,15 @@ end
5959

6060
function init_amp(
6161
rng::AbstractRNG;
62-
observations::ContextualSBMObservations{R1},
63-
csbm::ContextualSBM{R2},
64-
init_std::R3,
65-
) where {R1,R2,R3}
66-
R = promote_type(R1, R2, R3)
62+
observations::ContextualSBMObservations{R},
63+
csbm::ContextualSBM{R},
64+
init_std,
65+
) where {R}
6766
(; N, P) = csbm
68-
(; g) = observations
67+
(; g, Ξ) = observations
6968

70-
= init_std .* randn(rng, R, N)
71-
= init_std .* randn(rng, R, P)
69+
= prior₊.(R, Ξ) + init_std .* randn(rng, R, N)
70+
= 2 .* prior₊.(R, Ξ) .- one(R) .+ init_std .* randn(rng, R, P)
7271
χ₊e = Dict{Tuple{Int,Int},R}()
7372
for i in 1:N, j in neighbors(g, i)
7473
χ₊e[i, j] = (one(R) / 2) + init_std * randn(rng, R)
@@ -86,15 +85,18 @@ function init_amp(
8685
return (; storage, next_storage, temp_storage)
8786
end
8887

88+
prior₊(::Type{R}, Ξᵢ) where {R} = Ξᵢ == 0 ? one(R) / 2 : R(Ξᵢ == 1)
89+
prior₋(::Type{R}, Ξᵢ) where {R} = Ξᵢ == 0 ? one(R) / 2 : R(Ξᵢ == -1)
90+
8991
function update_amp!(
90-
next_storage::AMPStorage,
91-
temp_storage::AMPTempStorage;
92-
storage::AMPStorage,
93-
observations::ContextualSBMObservations,
92+
next_storage::AMPStorage{R},
93+
temp_storage::AMPTempStorage{R};
94+
storage::AMPStorage{R},
95+
observations::ContextualSBMObservations{R},
9496
csbm::ContextualSBM{R},
9597
) where {R}
9698
(; d, λ, μ, N, P) = csbm
97-
(; g, B) = observations
99+
(; g, B, Ξ) = observations
98100
(; cᵢ, cₒ) = affinities(csbm)
99101

100102
ûᵗ, v̂ᵗ, χ₊eᵗ = storage.û, storage.v̂, storage.χ₊e
@@ -109,19 +111,19 @@ function update_amp!(
109111
mul!(v̂_no_comm, B, ûᵗ)
110112
v̂_no_comm .*= sqrt/ N)
111113
v̂_no_comm .-=/ N) .* v̂ᵗ .* (N - ûₜ_sum2)
112-
v̂ᵗ⁺¹ .= v̂_no_comm ./ (1 + σᵥ_no_comm)
113-
σᵥ = 1 / (1 + σᵥ_no_comm)
114+
v̂ᵗ⁺¹ .= v̂_no_comm ./ (one(R) + σᵥ_no_comm)
115+
σᵥ = one(R) / (one(R) + σᵥ_no_comm)
114116

115117
# BP estimation of u
116118
mul!(û_no_feat, B', v̂ᵗ⁺¹)
117119
û_no_feat .*= sqrt/ N)
118120
û_no_feat .-=/ (N / P)) .* σᵥ .* ûᵗ
119121

120122
# Estimation of the field h
121-
h₊ = (1 / 2N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
122-
h₋ = (1 / 2N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
123-
h̃₊ .= -h₊ .+ log(one(R) / 2) .+ û_no_feat
124-
h̃₋ .= -h₋ .+ log(one(R) / 2) .- û_no_feat
123+
h₊ = (one(R) / 2N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
124+
h₋ = (one(R) / 2N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
125+
h̃₊ .= -h₊ .+ log.(prior₊.(R, Ξ)) .+ û_no_feat
126+
h̃₋ .= -h₋ .+ log.(prior₋.(R, Ξ)) .- û_no_feat
125127

126128
# BP update of the marginals
127129
for i in 1:N
@@ -147,7 +149,7 @@ function update_amp!(
147149
end
148150

149151
# BP estimation of u
150-
ûᵗ⁺¹ .= 2 .* χ₊ .- 1
152+
ûᵗ⁺¹ .= 2 .* χ₊ .- one(R)
151153

152154
return nothing
153155
end
@@ -158,18 +160,21 @@ function run_amp(
158160
csbm::ContextualSBM,
159161
init_std::Real=1e-3,
160162
iterations::Integer=10,
163+
show_progress=false,
161164
)
162165
(; storage, next_storage, temp_storage) = init_amp(rng; observations, csbm, init_std)
163166
storage_history = [copy(storage)]
164-
for iter in 1:iterations
167+
prog = Progress(iterations; desc="AMP-BP", enabled=show_progress)
168+
for _ in 1:iterations
165169
update_amp!(next_storage, temp_storage; storage, observations, csbm)
166170
copy!(storage, next_storage)
167171
push!(storage_history, copy(storage))
172+
next!(prog)
168173
end
169174
return storage_history
170175
end
171176

172-
function evaluate_amp(; storage::AMPStorage, latents::ContextualSBMLatents)
177+
function overlap(; storage::AMPStorage, latents::ContextualSBMLatents)
173178
= sign.(storage.û)
174179
@assert all(abs.(û) .> 0.5)
175180
u = latents.u

test/csbm.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ function test_recovery(csbm::ContextualSBM)
1010
@assert effective_snr(csbm) > 1
1111
(; latents, observations) = rand(rng, csbm)
1212
storage_history = run_amp(rng; observations, csbm)
13-
overlap_history = [evaluate_amp(; storage, latents) for storage in storage_history]
14-
@test last(overlap_history) > 10 * first(overlap_history)
13+
overlap_history = [overlap(; storage, latents) for storage in storage_history]
14+
@test last(overlap_history) > first(overlap_history)
1515
@test last(overlap_history) > 0.5
1616
return nothing
1717
end
@@ -41,6 +41,7 @@ end
4141
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) # AMP-BP
4242
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0)) # AMP
4343
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0)) # BP
44+
test_recovery(ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=0, ρ=0.5)) # semi-supervised
4445
end
4546

4647
@testset "Good code" begin

0 commit comments

Comments
 (0)