Skip to content

Commit 08ad19a

Browse files
committed
Plot kills Julia
1 parent 8476c13 commit 08ad19a

File tree

5 files changed

+87
-21
lines changed

5 files changed

+87
-21
lines changed

docs/plots_csbm.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,38 @@ using ProgressMeter
88

99
BLAS.set_num_threads(1)
1010

11-
function compute_fig1(; α, d, μ, ρ, N_values, λ_values, trials)
11+
function compute_fig1_csbm(; α, d, μ, ρ, N_values, λ_values, trials)
1212
rng = default_rng()
1313
I, J, K = length(N_values), length(λ_values), trials
1414
qu_values = zeros(I, J, K)
1515
qv_values = zeros(I, J, K)
1616
converged_values = zeros(Bool, I, J, K)
1717
prog = Progress(I * J * K; desc="CSBM - Fig 1")
18-
@threads for i in 1:I, j in 1:J, k in 1:K
19-
N, λ = N_values[i], λ_values[j]
20-
P = ceil(Int, N / α)
21-
csbm = CSBM(; N, P, d, μ, λ, ρ)
22-
(qu, qv, converged) = evaluate_amp(rng, csbm)
23-
qu_values[i, j, k] = qu
24-
qv_values[i, j, k] = qv
25-
converged_values[i, j, k] = converged
26-
next!(prog)
18+
for i in 1:I
19+
for j in 1:J
20+
@threads for k in 1:K
21+
N, λ = N_values[i], λ_values[j]
22+
P = ceil(Int, N / α)
23+
csbm = CSBM(; N, P, d, μ, λ, ρ)
24+
(qu, qv, converged) = evaluate_amp(rng, csbm)
25+
qu_values[i, j, k] = qu
26+
qv_values[i, j, k] = qv
27+
converged_values[i, j, k] = converged
28+
next!(prog)
29+
end
30+
end
2731
end
2832
return (; qu_values, qv_values)
2933
end
3034

31-
function plot_fig1(res; N_values, λ_values)
35+
function plot_fig1_csbm(res; N_values, λ_values)
3236
qu_values, qv_values = res
3337
f = Figure()
3438
for (i, N) in enumerate(N_values)
35-
ax = Axis(f[i, 1]; title=" N = $N", xlabel="λ", ylabel="qᵤ", limits=(0, 2, 0, 1))
36-
qu_means = dropdims(mean(qu_values[i, :, :]; dims=1); dims=1)
37-
qu_stds = dropdims(std(qu_values[i, :, :]; dims=1); dims=1)
39+
ax = Axis(f[i, 1]; title="N = $N", xlabel="λ", ylabel="q_u", limits=(0, 2, 0, 1))
40+
qu_means = dropdims(mean(qu_values[i, :, :]; dims=2); dims=2)
41+
qu_stds = dropdims(std(qu_values[i, :, :]; dims=2); dims=2)
3842
lines!(ax, λ_values, qu_means)
39-
scatter!(ax, λ_values, qu_means)
4043
errorbars!(ax, λ_values, qu_means, qu_stds)
4144
end
4245
return f
@@ -45,10 +48,10 @@ end
4548
α = 10
4649
d = 5
4750
μ = 2
48-
ρ = 0.1
51+
ρ = 0.0
4952
N_values = reverse([3 * 10^3, 10^4, 3 * 10^4, 10^5])
5053
N_values = reverse([3 * 10^3, 10^4])
5154
λ_values = 0:0.1:2
5255
trials = 10
53-
res1 = compute_fig1(; α, d, μ, ρ, N_values, λ_values, trials)
54-
plot_fig1(res1; N_values, λ_values)
56+
res1 = compute_fig1_csbm(; α, d, μ, ρ, N_values, λ_values, trials)
57+
plot_fig1_csbm(res1; N_values, λ_values)

docs/plots_glmsbm.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using Base.Threads
2+
using CairoMakie
3+
using LinearAlgebra
4+
using Random: default_rng
5+
using Statistics
6+
using StochasticBlockModelVariants
7+
using ProgressMeter
8+
9+
BLAS.set_num_threads(1)
10+
11+
function compute_fig1_glmsbm(; N, c, ρ, Pʷ, α_values, λ_values, trials)
12+
rng = default_rng()
13+
I, J, K = length(α_values), length(λ_values), trials
14+
qs_values = zeros(I, J, K)
15+
qw_values = zeros(I, J, K)
16+
converged_values = zeros(Bool, I, J, K)
17+
prog = Progress(I * J * K; desc="CSBM - Fig 1")
18+
for i in 1:I
19+
for j in 1:J
20+
@threads for k in 1:K
21+
α, λ = α_values[i], λ_values[j]
22+
M = ceil(Int, N / α)
23+
glmsbm = GLMSBM(; N, M, c, λ, ρ, Pʷ)
24+
(qs, qw, converged) = evaluate_amp(rng, glmsbm)
25+
qs_values[i, j, k] = qs
26+
qw_values[i, j, k] = qw
27+
converged_values[i, j, k] = converged
28+
next!(prog)
29+
end
30+
end
31+
end
32+
return (; qs_values, qw_values)
33+
end
34+
35+
function plot_fig1_glmsbm(res; α_values, λ_values)
36+
qs_values, qw_values = res
37+
f = Figure()
38+
ax1 = Axis(f[1, 1]; title="Communities", xlabel="λ", ylabel="q_s", limits=(0, 2, 0, 1))
39+
ax2 = Axis(f[1, 2]; title="Weights", xlabel="λ", ylabel="q_w", limits=(0, 2, 0, 1))
40+
for (i, α) in enumerate(α_values)
41+
qs_means = dropdims(mean(qs_values[i, :, :]; dims=2); dims=2)
42+
qs_stds = dropdims(std(qs_values[i, :, :]; dims=2); dims=2)
43+
lines!(ax1, λ_values, qs_means; label="α=")
44+
errorbars!(ax1, λ_values, qs_means, qs_stds; label=nothing)
45+
46+
qw_means = dropdims(mean(qs_values[i, :, :]; dims=2); dims=2)
47+
qw_stds = dropdims(std(qs_values[i, :, :]; dims=2); dims=2)
48+
lines!(ax2, λ_values, qw_means; label="α=")
49+
errorbars!(ax2, λ_values, qw_means, qw_stds; label=nothing)
50+
end
51+
return f
52+
end
53+
54+
N = 10^4
55+
c = 5
56+
ρ = 0.0
57+
= GaussianWeightPrior()
58+
α_values = reverse([0.3, 1, 3, 10, 30])
59+
α_values = reverse([0.3, 1])
60+
λ_values = 0:0.1:2
61+
trials = 10
62+
res1 = compute_fig1_glmsbm(; N, c, ρ, Pʷ, α_values, λ_values, trials) # kills Julia
63+
plot_fig1_glmsbm(res1; α_values, λ_values)

src/csbm_inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function run_amp(
129129
init_std=1e-3,
130130
max_iterations=200,
131131
convergence_threshold=1e-3,
132-
recent_past=10,
132+
recent_past=5,
133133
damping=0.0,
134134
show_progress=false,
135135
)

src/glmsbm_inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ function run_amp(
170170
init_std=1e-3,
171171
max_iterations=100,
172172
convergence_threshold=1e-3,
173-
recent_past=10,
173+
recent_past=5,
174174
damping=0.5,
175175
show_progress=false,
176176
)

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
GLMSBM(; N=10^2, M=10^2, c=5, λ=2, ρ=0.0, Pʷ=GaussianWeightPrior()),
66
GLMSBM(; N=10^2, M=10^2, c=5, λ=2, ρ=0.1, Pʷ=RademacherWeightPrior()),
77
]
8-
evaluate_amp(rng, sbm)
8+
evaluate_amp(rng, sbm; max_iterations=2)
99
end
1010
end

0 commit comments

Comments
 (0)