Skip to content

Commit bc8574b

Browse files
committed
First plot
1 parent 55fa515 commit bc8574b

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

docs/plots.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,40 @@ using Base.Threads
22
using CairoMakie
33
using LinearAlgebra
44
using Random: default_rng
5+
using Statistics
56
using StochasticBlockModelVariants
67
using ProgressMeter
78

89
BLAS.set_num_threads(1)
10+
11+
function compute_fig1(; N=3 * 10^3, P=N ÷ 10, d=5, μ=2, ρ=0.1, λ_values=0:0.1:2, trials=10)
12+
rng = default_rng()
13+
qᵤ_values = zeros(trials, length(λ_values))
14+
qᵥ_values = zeros(trials, length(λ_values))
15+
prog = Progress(trials * length(λ_values); desc="Fig 1 - trials")
16+
@threads for i in 1:trials
17+
for j in eachindex(λ_values)
18+
λ = λ_values[j]
19+
csbm = ContextualSBM(; N, P, d, μ, λ, ρ)
20+
(; qᵤ, qᵥ) = evaluate_amp(rng; csbm)
21+
qᵤ_values[i, j] = qᵤ
22+
qᵥ_values[i, j] = qᵥ
23+
next!(prog)
24+
end
25+
end
26+
return (; λ_values, qᵤ_values, qᵥ_values)
27+
end
28+
29+
function plot_fig1(res)
30+
f = Figure()
31+
ax = Axis(f[1, 1]; title="Fig 1", xlabel="λ", ylabel="qᵤ", limits=(0, 2, 0, 1))
32+
qᵤ_means = dropdims(mean(res.qᵤ_values; dims=1); dims=1)
33+
qᵤ_stds = dropdims(std(res.qᵤ_values; dims=1); dims=1)
34+
lines!(ax, λ_values, qᵤ_means)
35+
scatter!(ax, λ_values, qᵤ_means)
36+
errorbars!(ax, λ_values, qᵤ_means, qᵤ_stds)
37+
return f
38+
end
39+
40+
res1 = compute_fig1()
41+
plot_fig1(res1)

src/StochasticBlockModelVariants.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using SparseArrays: SparseMatrixCSC, sparse, findnz
1111

1212
export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations
1313
export affinities, effective_snr
14-
export init_amp, update_amp!, run_amp, overlaps
14+
export overlaps
15+
export init_amp, update_amp!, run_amp, evaluate_amp
1516

1617
include("utils.jl")
1718
include("csbm.jl")

src/csbm_inference.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,11 @@ function run_amp(
175175
rng::AbstractRNG;
176176
observations::ContextualSBMObservations{R},
177177
csbm::ContextualSBM{R},
178-
init_std::Real=1e-3,
179-
max_iterations::Integer=200,
178+
init_std=1e-3,
179+
max_iterations=200,
180180
convergence_threshold=1e-3,
181181
recent_past=10,
182-
show_progress::Bool=false,
182+
show_progress=false,
183183
) where {R}
184184
(; N, P) = csbm
185185
(; marginals, next_marginals, storage) = init_amp(rng; observations, csbm, init_std)
@@ -222,3 +222,12 @@ function run_amp(
222222

223223
return (; û_history, v̂_history, converged)
224224
end
225+
226+
function evaluate_amp(rng::AbstractRNG; csbm::ContextualSBM, kwargs...)
227+
(; latents, observations) = rand(rng, csbm)
228+
(; û_history, v̂_history, converged) = run_amp(rng; observations, csbm, kwargs...)
229+
(; qᵤ, qᵥ) = overlaps(;
230+
u=latents.u, v=latents.v, û=û_history[:, end], v̂=v̂_history[:, end]
231+
)
232+
return (; qᵤ, qᵥ)
233+
end

test/csbm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ function test_recovery(csbm::ContextualSBM; test_u=true, test_v=true)
1414
first_q = overlaps(; u, v, û=û_history[:, begin], v̂=v̂_history[:, begin])
1515
last_q = overlaps(; u, v, û=û_history[:, end], v̂=v̂_history[:, end])
1616
if test_u
17-
@test last_q.qᵤ > first_q.qᵤ
17+
@test last_q.qᵤ >= first_q.qᵤ
1818
@test last_q.qᵤ > 0.5
1919
end
2020
if test_v
21-
@test last_q.qᵥ > first_q.qᵥ
21+
@test last_q.qᵥ >= first_q.qᵥ
2222
@test last_q.qᵥ > 0.5
2323
end
2424
return nothing

0 commit comments

Comments
 (0)