Skip to content

Commit 55fa515

Browse files
committed
Convergence criterion
1 parent ebde6d0 commit 55fa515

File tree

7 files changed

+137
-75
lines changed

7 files changed

+137
-75
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.9.2"
44
manifest_format = "2.0"
5-
project_hash = "eabebeb24d3db0211e95e73a9e2e965bf16ba4e1"
5+
project_hash = "dd3777d035d7f906717e2a66a2e6f5c3854a6b1a"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

docs/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[deps]
2+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
6+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
38
StochasticBlockModelVariants = "37adabcb-1964-4c56-850d-39a4302c0c39"

docs/plots.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using Base.Threads
2+
using CairoMakie
3+
using LinearAlgebra
4+
using Random: default_rng
5+
using StochasticBlockModelVariants
6+
using ProgressMeter
7+
8+
BLAS.set_num_threads(1)
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
module StochasticBlockModelVariants
22

33
using Graphs: AbstractGraph, neighbors
4-
using LinearAlgebra: dot, mul!
4+
using LinearAlgebra: dot, mul!, norm
55
using PrecompileTools: @compile_workload
66
using ProgressMeter: Progress, next!
77
using Random: AbstractRNG, default_rng
88
using SimpleWeightedGraphs: SimpleWeightedGraph
9-
using Statistics: mean
9+
using Statistics: mean, std
1010
using SparseArrays: SparseMatrixCSC, sparse, findnz
1111

1212
export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations
1313
export affinities, effective_snr
14-
export init_amp, update_amp!, run_amp
15-
export overlap
14+
export init_amp, update_amp!, run_amp, overlaps
1615

1716
include("utils.jl")
1817
include("csbm.jl")
1918
include("csbm_inference.jl")
2019

2120
@compile_workload begin
2221
rng = default_rng()
23-
csbm = ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.0)
22+
csbm = ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.1)
2423
(; observations) = rand(rng, csbm)
25-
run_amp(rng; observations, csbm, iterations=2)
24+
run_amp(rng; observations, csbm, max_iterations=2)
2625
end
2726

2827
end

src/csbm_inference.jl

Lines changed: 94 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,50 @@
1-
## Storage
1+
## Marginals
22

33
"""
4-
AMPStorage
4+
AMPMarginals
55
66
# Fields
77
88
- `û::Vector`: posterior mean of `u`, length `N`
99
- `v̂::Vector`: posterior mean of `v`, length `P`
1010
- `χ₊e::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
1111
"""
12-
@kwdef struct AMPStorage{R<:Real}
12+
@kwdef struct AMPMarginals{R<:Real}
1313
::Vector{R}
1414
::Vector{R}
1515
χ₊e::Dict{Tuple{Int,Int},R}
1616
end
1717

18-
function Base.copy(storage::AMPStorage)
19-
return AMPStorage(; û=copy(storage.û), v̂=copy(storage.v̂), χ₊e=copy(storage.χ₊e))
18+
function Base.copy(marginals::AMPMarginals)
19+
return AMPMarginals(;
20+
=copy(marginals.û), v̂=copy(marginals.v̂), χ₊e=copy(marginals.χ₊e)
21+
)
22+
end
23+
24+
function Base.copy!(marginals_dest::AMPMarginals, marginals_source::AMPMarginals)
25+
marginals_dest.û .= marginals_source.
26+
marginals_dest.v̂ .= marginals_source.
27+
copy!(marginals_dest.χ₊e, marginals_source.χ₊e)
28+
return marginals_dest
2029
end
2130

22-
function Base.copy!(storage_dest::AMPStorage, storage_source::AMPStorage)
23-
storage_dest.û .= storage_source.
24-
storage_dest.v̂ .= storage_source.
25-
copy!(storage_dest.χ₊e, storage_source.χ₊e)
26-
return storage_dest
31+
function overlaps(;
32+
u::Vector{<:Integer}, v::Vector{R}, û::Vector{R}, v̂::Vector{R}
33+
) where {R}
34+
û .= sign.(û)
35+
û[abs.(û) .< eps(R)] .= one(R)
36+
37+
q̂ᵤ = max(freq_equalities(û, u), freq_equalities(û, -u))
38+
qᵤ = 2 * (q̂ᵤ - one(R) / 2)
39+
40+
q̂ᵥ = max(abs(dot(v̂, v)), abs(dot(v̂, -v)))
41+
qᵥ = q̂ᵥ / (eps(R) + norm(v̂) * norm(v))
42+
43+
return (; qᵤ, qᵥ)
2744
end
2845

2946
"""
30-
AMPTempStorage
47+
AMPStorage
3148
3249
# Fields
3350
@@ -37,21 +54,21 @@ end
3754
- `h̃₋::Vector`: individual external field for `u=-1`, length `N`
3855
- `χ₊::Vector`: marginal probability of `u=1`, length `N`
3956
"""
40-
@kwdef struct AMPTempStorage{R<:Real}
57+
@kwdef struct AMPStorage{R<:Real}
4158
û_no_feat::Vector{R}
4259
v̂_no_comm::Vector{R}
4360
h̃₊::Vector{R}
4461
h̃₋::Vector{R}
4562
χ₊::Vector{R}
4663
end
4764

48-
function Base.copy(temp_storage::AMPTempStorage)
49-
return AMPTempStorage(;
50-
û_no_feat=copy(temp_storage.û_no_feat),
51-
v̂_no_comm=copy(temp_storage.v̂_no_comm),
52-
h̃₊=copy(temp_storage.h̃₊),
53-
h̃₋=copy(temp_storage.h̃₋),
54-
χ₊=copy(temp_storage.χ₊),
65+
function Base.copy(storage::AMPStorage)
66+
return AMPStorage(;
67+
û_no_feat=copy(storage.û_no_feat),
68+
v̂_no_comm=copy(storage.v̂_no_comm),
69+
h̃₊=copy(storage.h̃₊),
70+
h̃₋=copy(storage.h̃₋),
71+
χ₊=copy(storage.χ₊),
5572
)
5673
end
5774

@@ -66,42 +83,42 @@ function init_amp(
6683
(; N, P) = csbm
6784
(; g, Ξ) = observations
6885

69-
= prior₊.(R, Ξ) + init_std .* randn(rng, R, N)
70-
= 2 .* prior₊.(R, Ξ) .- one(R) .+ init_std .* randn(rng, R, P)
86+
= 2 .* prior₊.(R, Ξ) .- one(R) .+ init_std .* randn(rng, R, N)
87+
= init_std .* randn(rng, R, P)
7188
χ₊e = Dict{Tuple{Int,Int},R}()
7289
for i in 1:N, j in neighbors(g, i)
73-
χ₊e[i, j] = (one(R) / 2) + init_std * randn(rng, R)
90+
χ₊e[i, j] = prior₊(R, Ξ[i]) + init_std * randn(rng, R)
7491
end
75-
storage = AMPStorage(; û, v̂, χ₊e)
76-
next_storage = copy(storage)
92+
marginals = AMPMarginals(; û, v̂, χ₊e)
93+
next_marginals = copy(marginals)
7794

7895
û_no_feat = zeros(R, N)
7996
v̂_no_comm = zeros(R, P)
8097
h̃₊ = zeros(R, N)
8198
h̃₋ = zeros(R, N)
8299
χ₊ = zeros(R, N)
83-
temp_storage = AMPTempStorage(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊)
100+
storage = AMPStorage(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊)
84101

85-
return (; storage, next_storage, temp_storage)
102+
return (; marginals, next_marginals, storage)
86103
end
87104

88105
prior₊(::Type{R}, Ξᵢ) where {R} = Ξᵢ == 0 ? one(R) / 2 : R(Ξᵢ == 1)
89106
prior₋(::Type{R}, Ξᵢ) where {R} = Ξᵢ == 0 ? one(R) / 2 : R(Ξᵢ == -1)
90107

91108
function update_amp!(
92-
next_storage::AMPStorage{R},
93-
temp_storage::AMPTempStorage{R};
94-
storage::AMPStorage{R},
109+
next_marginals::AMPMarginals{R},
110+
storage::AMPStorage{R};
111+
marginals::AMPMarginals{R},
95112
observations::ContextualSBMObservations{R},
96113
csbm::ContextualSBM{R},
97114
) where {R}
98115
(; d, λ, μ, N, P) = csbm
99116
(; g, B, Ξ) = observations
100117
(; cᵢ, cₒ) = affinities(csbm)
101118

102-
ûᵗ, v̂ᵗ, χ₊eᵗ = storage.û, storage.v̂, storage.χ₊e
103-
ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_storage.û, next_storage.v̂, next_storage.χ₊e
104-
(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = temp_storage
119+
ûᵗ, v̂ᵗ, χ₊eᵗ = marginals.û, marginals.v̂, marginals.χ₊e
120+
ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_marginals.û, next_marginals.v̂, next_marginals.χ₊e
121+
(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = storage
105122

106123
ûₜ_sum = sum(ûᵗ)
107124
ûₜ_sum2 = sum(abs2, ûᵗ)
@@ -156,29 +173,52 @@ end
156173

157174
function run_amp(
158175
rng::AbstractRNG;
159-
observations::ContextualSBMObservations,
160-
csbm::ContextualSBM,
176+
observations::ContextualSBMObservations{R},
177+
csbm::ContextualSBM{R},
161178
init_std::Real=1e-3,
162-
iterations::Integer=10,
163-
show_progress=false,
164-
)
165-
(; storage, next_storage, temp_storage) = init_amp(rng; observations, csbm, init_std)
166-
storage_history = [copy(storage)]
167-
prog = Progress(iterations; desc="AMP-BP", enabled=show_progress)
168-
for _ in 1:iterations
169-
update_amp!(next_storage, temp_storage; storage, observations, csbm)
170-
copy!(storage, next_storage)
171-
push!(storage_history, copy(storage))
172-
next!(prog)
179+
max_iterations::Integer=200,
180+
convergence_threshold=1e-3,
181+
recent_past=10,
182+
show_progress::Bool=false,
183+
) where {R}
184+
(; N, P) = csbm
185+
(; marginals, next_marginals, storage) = init_amp(rng; observations, csbm, init_std)
186+
187+
û_history = Matrix{R}(undef, N, max_iterations)
188+
v̂_history = Matrix{R}(undef, P, max_iterations)
189+
converged = false
190+
prog = Progress(max_iterations; desc="AMP-BP", enabled=show_progress)
191+
192+
for t in 1:max_iterations
193+
update_amp!(next_marginals, storage; marginals, observations, csbm)
194+
copy!(marginals, next_marginals)
195+
196+
û_history[:, t] .= marginals.
197+
v̂_history[:, t] .= marginals.
198+
199+
if t <= recent_past
200+
û_recent_std = typemax(R)
201+
v̂_recent_std = typemax(R)
202+
else
203+
û_recent_std = maximum(std(view(û_history, :, (t - recent_past):t); dims=2))
204+
v̂_recent_std = maximum(std(view(v̂_history, :, (t - recent_past):t); dims=2))
205+
end
206+
converged = (
207+
û_recent_std < convergence_threshold && v̂_recent_std < convergence_threshold
208+
)
209+
if converged
210+
û_history = û_history[:, 1:t]
211+
v̂_history = v̂_history[:, 1:t]
212+
break
213+
else
214+
showvalues = [
215+
(:û_recent_std, û_recent_std),
216+
(:v̂_recent_std, v̂_recent_std),
217+
(:convergence_threshold, convergence_threshold),
218+
]
219+
next!(prog; showvalues)
220+
end
173221
end
174-
return storage_history
175-
end
176222

177-
function overlap(; storage::AMPStorage, latents::ContextualSBMLatents)
178-
= sign.(storage.û)
179-
@assert all(abs.(û) .> 0.5)
180-
u = latents.u
181-
q̂ᵤ = (1 / length(û)) * max(count_equalities(û, u), count_equalities(û, -u))
182-
qᵤ = 2 * (q̂ᵤ - 0.5)
183-
return qᵤ
223+
return (; û_history, v̂_history, converged)
184224
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
sigmoid(x) = 1 / (1 + exp(-x))
22

3-
count_equalities(x, y) = sum(x[i] y[i] for i in eachindex(x, y))
3+
freq_equalities(x, y) = mean(x[i] y[i] for i in eachindex(x, y))

test/csbm.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,54 @@ using Statistics
44
using StochasticBlockModelVariants
55
using Test
66

7-
rng = default_rng()
8-
9-
function test_recovery(csbm::ContextualSBM)
7+
function test_recovery(csbm::ContextualSBM; test_u=true, test_v=true)
8+
rng = default_rng()
109
@assert effective_snr(csbm) > 1
1110
(; latents, observations) = rand(rng, csbm)
12-
storage_history = run_amp(rng; observations, csbm)
13-
overlap_history = [overlap(; storage, latents) for storage in storage_history]
14-
@test last(overlap_history) > first(overlap_history)
15-
@test last(overlap_history) > 0.5
11+
(; u, v) = latents
12+
(; û_history, v̂_history, converged) = run_amp(rng; observations, csbm)
13+
@test converged
14+
first_q = overlaps(; u, v, û=û_history[:, begin], v̂=v̂_history[:, begin])
15+
last_q = overlaps(; u, v, û=û_history[:, end], v̂=v̂_history[:, end])
16+
if test_u
17+
@test last_q.qᵤ > first_q.qᵤ
18+
@test last_q.qᵤ > 0.5
19+
end
20+
if test_v
21+
@test last_q.qᵥ > first_q.qᵥ
22+
@test last_q.qᵥ > 0.5
23+
end
1624
return nothing
1725
end
1826

1927
function test_jet(csbm::ContextualSBM)
28+
rng = default_rng()
2029
(; observations) = rand(rng, csbm)
2130
@test_opt target_modules = (StochasticBlockModelVariants,) run_amp(
22-
rng; observations, csbm, iterations=2
31+
rng; observations, csbm, max_iterations=2
2332
)
2433
@test_call target_modules = (StochasticBlockModelVariants,) run_amp(
25-
rng; observations, csbm, iterations=2
34+
rng; observations, csbm, max_iterations=2
2635
)
2736
return nothing
2837
end
2938

3039
function test_allocations(csbm::ContextualSBM)
40+
rng = default_rng()
3141
(; observations) = rand(rng, csbm)
32-
(; storage, next_storage, temp_storage) = init_amp(
42+
(; marginals, next_marginals, storage) = init_amp(
3343
rng; observations, csbm, init_std=1e-3
3444
)
35-
alloc = @allocated update_amp!(next_storage, temp_storage; storage, observations, csbm)
45+
alloc = @allocated update_amp!(next_marginals, storage; marginals, observations, csbm)
3646
@test alloc == 0
3747
return nothing
3848
end
3949

4050
@testset "Correct code" begin
4151
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) # AMP-BP
4252
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0)) # AMP
43-
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0)) # BP
44-
test_recovery(ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=0, ρ=0.5)) # semi-supervised
53+
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0); test_v=false) # BP
54+
test_recovery(ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.5)) # semi-supervised
4555
end
4656

4757
@testset "Good code" begin

0 commit comments

Comments
 (0)