Skip to content

Commit aa9914d

Browse files
committed
Working GLM-SBM ?
1 parent aeb63f9 commit aa9914d

File tree

14 files changed

+123
-146
lines changed

14 files changed

+123
-146
lines changed

.vscode/settings.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

Manifest.toml

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0-beta2"
44
manifest_format = "2.0"
5-
project_hash = "ba25bdf8dd9a1a3a609260b65801f3325dbdaa20"
5+
project_hash = "06fe3f2e3a2d6f17958684cbeb77c96e1e718454"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -75,12 +75,6 @@ git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4"
7575
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
7676
version = "1.8.0"
7777

78-
[[deps.Infiltrator]]
79-
deps = ["InteractiveUtils", "Markdown", "REPL", "UUIDs"]
80-
git-tree-sha1 = "04de041e4590428cccbb026d86e5d670513be2e3"
81-
uuid = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
82-
version = "1.6.4"
83-
8478
[[deps.Inflate]]
8579
git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
8680
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
@@ -149,12 +143,6 @@ version = "0.3.26"
149143
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
150144
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
151145

152-
[[deps.LogarithmicNumbers]]
153-
deps = ["Random"]
154-
git-tree-sha1 = "8522befb54ff3b4bcf17d57b14b884d536a22015"
155-
uuid = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
156-
version = "1.2.1"
157-
158146
[[deps.Logging]]
159147
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
160148

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ version = "0.1.0"
77
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
10-
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12-
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
1311
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1412
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1513
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -20,6 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2018

2119
[compat]
2220
DensityInterface = "0.4"
21+
DocStringExtensions = "0.9"
2322
Graphs = "1.8"
2423
PrecompileTools = "1.1"
2524
ProgressMeter = "1.7"

src/StochasticBlockModelVariants.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ module StochasticBlockModelVariants
1212
using DensityInterface: logdensityof, densityof
1313
using DocStringExtensions
1414
using Graphs: AbstractGraph, neighbors
15-
using Infiltrator
1615
using LinearAlgebra: LinearAlgebra, dot, mul!, norm, normalize!
1716
using PrecompileTools: @compile_workload
1817
using ProgressMeter: Progress, next!
@@ -22,10 +21,11 @@ using SpecialFunctions: erf
2221
using Statistics: mean, std
2322
using SparseArrays: SparseMatrixCSC, sparse, findnz
2423

24+
export AbstractSBM
2525
export nb_features, average_degree, communities_snr, affinities, features_snr, effective_snr
2626
export CSBM, LatentsCSBM, ObservationsCSBM
2727
export GLMSBM, LatentsGLMSBM, ObservationsGLMSBM
28-
export overlaps
28+
export GaussianWeightPrior, RademacherWeightPrior
2929
export init_amp, update_amp!, run_amp, evaluate_amp
3030

3131
include("abstract_sbm.jl")
@@ -34,11 +34,4 @@ include("glmsbm.jl")
3434
include("csbm_inference.jl")
3535
include("glmsbm_inference.jl")
3636

37-
# @compile_workload begin
38-
# rng = default_rng()
39-
# csbm = CSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.1)
40-
# (; observations) = rand(rng, csbm)
41-
# run_amp(rng; observations, csbm, max_iterations=2)
42-
# end
43-
4437
end

src/abstract_sbm.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,17 @@ sigmoid(x) = 1 / (1 + exp(-x))
105105
freq_equalities(x, y) = mean(x[i] y[i] for i in eachindex(x, y))
106106

107107
function discrete_overlap(u, û)
108-
q̂ᵤ = max(freq_equalities(û, u), freq_equalities(û, -u))
108+
R = eltype(û)
109+
û_sign = sign.(û)
110+
û_sign[abs.(û) .< eps(R)] .= 1
111+
q̂ᵤ = max(freq_equalities(u, û_sign), freq_equalities(u, -û_sign))
109112
qᵤ = 2 * (q̂ᵤ - one(R) / 2)
110113
return qᵤ
111114
end
112115

113116
function continuous_overlap(v, v̂)
117+
R = eltype(v)
114118
q̂ᵥ = max(abs(dot(v̂, v)), abs(dot(v̂, -v)))
115-
qᵥ = q̂ᵥ / (eps() + norm(v̂) * norm(v))
119+
qᵥ = q̂ᵥ / (eps(R) + norm(v̂) * norm(v))
116120
return qᵥ
117121
end

src/csbm_inference.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ function run_amp(
152152
observations::ObservationsCSBM,
153153
csbm::CSBM;
154154
init_std=1e-3,
155-
max_iterations=200,
155+
max_iterations=100,
156156
convergence_threshold=1e-3,
157157
recent_past=10,
158158
show_progress=false,
@@ -177,8 +177,8 @@ function run_amp(
177177
û_recent_std = typemax(R)
178178
v̂_recent_std = typemax(R)
179179
else
180-
û_recent_std = maximum(std(view(û_history, :, (t - recent_past):t); dims=2))
181-
v̂_recent_std = maximum(std(view(v̂_history, :, (t - recent_past):t); dims=2))
180+
û_recent_std = mean(std(view(û_history, :, (t - recent_past):t); dims=2))
181+
v̂_recent_std = mean(std(view(v̂_history, :, (t - recent_past):t); dims=2))
182182
end
183183
converged = (
184184
û_recent_std < convergence_threshold && v̂_recent_std < convergence_threshold
@@ -203,7 +203,7 @@ end
203203
function evaluate_amp(rng::AbstractRNG, csbm::CSBM; kwargs...)
204204
(; latents, observations) = rand(rng, csbm)
205205
(; û_history, v̂_history, converged) = run_amp(rng, observations, csbm; kwargs...)
206-
qᵤ = discrete_overlap(latents.u, û_history[:, end])
207-
qᵥ = continuous_overlap(latents.v, v̂_history[:, end])
208-
return (; qᵤ, qᵥ)
206+
q_dis = discrete_overlap(latents.u, û_history[:, end])
207+
q_cont = continuous_overlap(latents.v, v̂_history[:, end])
208+
return (; q_dis, q_cont, converged)
209209
end

src/glmsbm.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ $(TYPEDFIELDS)
7171
g::G
7272
"revealed communities `±1` for a fraction `ρ` of nodes and `missing` for the rest, ngth `N`"
7373
Ξ::Vector{Union{Int,Missing}}
74-
"feature matrix, size `(M, N)`"
74+
"feature matrix, size `(N, M)`"
7575
F::Matrix{R}
7676
end
7777

@@ -85,7 +85,7 @@ Sample from a [`GLMSBM`](@ref) and return a named tuple `(; latents, observation
8585
function Base.rand(rng::AbstractRNG, glmsbm::GLMSBM)
8686
(; N, M, ρ, Pʷ) = glmsbm
8787

88-
F = randn(rng, M, N) ./ sqrt(M)
88+
F = randn(rng, N, M) ./ sqrt(M)
8989

9090
w = [rand(Pʷ) for l in 1:M]
9191
s = round.(Int, sign.(F * w))
@@ -103,6 +103,9 @@ end
103103
struct RademacherWeightPrior{R} end
104104
struct GaussianWeightPrior{R} end
105105

106+
RademacherWeightPrior() = RademacherWeightPrior{Float64}()
107+
GaussianWeightPrior() = GaussianWeightPrior{Float64}()
108+
106109
function Base.rand(rng::AbstractRNG, ::RademacherWeightPrior{R}) where {R}
107110
return rand(rng, (-one(R), +one(R)))
108111
end

src/glmsbm_inference.jl

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
ind(s) = mod(s, 3) # sends 1 to 1 and -1 to 2
6767

6868
function init_amp(
69-
rng::AbstractRNG; observations::ObservationsGLMSBM{R1}, glmsbm::GLMSBM{R2}, init_std::R3
69+
rng::AbstractRNG, observations::ObservationsGLMSBM{R1}, glmsbm::GLMSBM{R2}; init_std::R3
7070
) where {R1,R2,R3}
7171
R = promote_type(R1, R2, R3)
7272
(; N, M) = glmsbm
@@ -96,7 +96,7 @@ function init_amp(
9696
end
9797

9898
function update_amp!(
99-
next_marginals::MarginalsGLMSBM{R};
99+
next_marginals::MarginalsGLMSBM{R},
100100
marginals::MarginalsGLMSBM{R},
101101
observations::ObservationsGLMSBM,
102102
glmsbm::GLMSBM,
@@ -143,38 +143,21 @@ function update_amp!(
143143
end
144144
@views gₒᵗ⁺¹ = gₒ.(ωᵗ⁺¹, χlᵗ[1, :], Ref(Vᵗ⁺¹))
145145
Λᵗ⁺¹ = sum(abs2, gₒᵗ⁺¹) / M
146-
mul!(Γᵗ⁺¹, F, gₒᵗ⁺¹)
146+
mul!(Γᵗ⁺¹, F', gₒᵗ⁺¹)
147147
Γᵗ⁺¹ .+= Λᵗ⁺¹ .* ŵᵗ
148148

149149
# AMP update of the estimated marginals a, v
150-
ŵᵗ⁺¹ .= fₐ.(Ref(Pʷ), Ref(Λᵗ⁺¹), Γᵗ⁺¹)
151-
vᵗ⁺¹ .= fᵥ.(Ref(Pʷ), Ref(Λᵗ⁺¹), Γᵗ⁺¹)
150+
ŵᵗ⁺¹ .= fₐ.(Ref(Pʷ), Λᵗ⁺¹, Γᵗ⁺¹)
151+
vᵗ⁺¹ .= fᵥ.(Ref(Pʷ), Λᵗ⁺¹, Γᵗ⁺¹)
152152

153153
# BP update of the field h
154154
hᵗ⁺¹ = Vector{R}(undef, 2)
155155
for s in (-1, 1)
156-
hᵗ⁺¹[ind(s)] = sum(
157-
C[ind(s), ind(sμ)] * χᵗ[ind(sμ), μ] for μ in 1:N forin (-1, 1)
158-
)
156+
hᵗ⁺¹[ind(s)] =
157+
sum(C[ind(s), ind(sμ)] * χᵗ[ind(sμ), μ] for μ in 1:N forin (-1, 1)) / N
159158
end
160159

161160
# BP update of the messages χe and of the marginals χ
162-
for μ in 1:N, ν in neighbors(g, μ)
163-
forin (-1, 1)
164-
χeᵗ⁺¹[ind(sμ), μ, ν] =
165-
prior(R, sμ, Ξ[μ]) * exp(-hᵗ⁺¹[ind(sμ)]) * ψlᵗ⁺¹[ind(sμ), μ]
166-
for η in neighbors(g, μ)
167-
if η != ν
168-
χeᵗ⁺¹[ind(sμ), μ, ν] *= sum(
169-
C[ind(sη), ind(sμ)] * χeᵗ[ind(sη), η, μ] forin (-1, 1)
170-
)
171-
end
172-
end
173-
end
174-
normalization = χeᵗ⁺¹[1, μ, ν] + χeᵗ⁺¹[2, μ, ν]
175-
χeᵗ⁺¹[1, μ, ν] /= normalization
176-
χeᵗ⁺¹[2, μ, ν] /= normalization
177-
end
178161

179162
for μ in 1:N
180163
forin (-1, 1)
@@ -188,6 +171,16 @@ function update_amp!(
188171
@views χᵗ⁺¹[:, μ] ./= sum(χᵗ⁺¹[:, μ])
189172
end
190173

174+
for μ in 1:N, ν in neighbors(g, μ)
175+
forin (-1, 1)
176+
extra_factor = sum(C[ind(sν), ind(sμ)] * χeᵗ[ind(sν), ν, μ] forin (-1, 1))
177+
χeᵗ⁺¹[ind(sμ), μ, ν] = χᵗ⁺¹[ind(sμ), μ] / extra_factor
178+
end
179+
normalization = χeᵗ⁺¹[1, μ, ν] + χeᵗ⁺¹[2, μ, ν]
180+
χeᵗ⁺¹[1, μ, ν] /= normalization
181+
χeᵗ⁺¹[2, μ, ν] /= normalization
182+
end
183+
191184
# BP update of the SBM-to-GLM messages χl
192185
for μ in 1:N
193186
forin (-1, 1)
@@ -201,7 +194,7 @@ function update_amp!(
201194
@views χlᵗ⁺¹[:, μ] ./= sum(χlᵗ⁺¹[:, μ])
202195
end
203196

204-
@views ⁺¹ .= 2 .* χ[1, :] .- one(R)
197+
@views ŝᵗ⁺¹ .= 2 .* χᵗ⁺¹[1, :] .- one(R)
205198

206199
return nothing
207200
end
@@ -211,13 +204,13 @@ function run_amp(
211204
observations::ObservationsGLMSBM,
212205
glmsbm::GLMSBM;
213206
init_std=1e-3,
214-
max_iterations=200,
207+
max_iterations=100,
215208
convergence_threshold=1e-3,
216209
recent_past=10,
217210
show_progress=false,
218211
)
219-
(; N, M) = csbm
220-
(; marginals, next_marginals) = init_amp(rng; observations, glmsbm, init_std)
212+
(; N, M) = glmsbm
213+
(; marginals, next_marginals) = init_amp(rng, observations, glmsbm; init_std)
221214

222215
R = eltype(marginals)
223216
ŝ_history = Matrix{R}(undef, N, max_iterations)
@@ -226,7 +219,7 @@ function run_amp(
226219
prog = Progress(max_iterations; desc="AMP-BP for GLM-SBM", enabled=show_progress)
227220

228221
for t in 1:max_iterations
229-
update_amp!(next_marginals; marginals, observations, glmsbm)
222+
update_amp!(next_marginals, marginals, observations, glmsbm)
230223
copy!(marginals, next_marginals)
231224

232225
ŝ_history[:, t] .= marginals.
@@ -236,8 +229,8 @@ function run_amp(
236229
ŝ_recent_std = typemax(R)
237230
ŵ_recent_std = typemax(R)
238231
else
239-
ŝ_recent_std = maximum(std(view(ŝ_history, :, (t - recent_past):t); dims=2))
240-
ŵ_recent_std = maximum(std(view(ŵ_history, :, (t - recent_past):t); dims=2))
232+
ŝ_recent_std = mean(std(view(ŝ_history, :, (t - recent_past):t); dims=2))
233+
ŵ_recent_std = mean(std(view(ŵ_history, :, (t - recent_past):t); dims=2))
241234
end
242235
converged = (
243236
ŝ_recent_std < convergence_threshold && ŵ_recent_std < convergence_threshold
@@ -261,8 +254,8 @@ end
261254

262255
function evaluate_amp(rng::AbstractRNG, glmsbm::GLMSBM; kwargs...)
263256
(; latents, observations) = rand(rng, glmsbm)
264-
(; ŝ_history, ŵ_history, converged) = run_amp(rng, observations, csbm; kwargs...)
265-
qᵤ = discrete_overlap(latents.s, ŝ_history[:, end])
266-
qᵥ = continuous_overlap(latents.w, ŵ_history[:, end])
267-
return (; qᵤ, qᵥ)
257+
(; ŝ_history, ŵ_history, converged) = run_amp(rng, observations, glmsbm; kwargs...)
258+
q_dis = discrete_overlap(latents.s, ŝ_history[:, end])
259+
q_cont = continuous_overlap(latents.w, ŵ_history[:, end])
260+
return (; q_dis, q_cont, converged)
268261
end

test/allocations.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using JET
2+
using Random: default_rng
3+
using Statistics
4+
using StochasticBlockModelVariants
5+
using Test
6+
7+
function test_allocations(sbm::AbstractSBM)
8+
rng = default_rng()
9+
(; observations) = rand(rng, sbm)
10+
(; marginals, next_marginals) = init_amp(rng, observations, sbm; init_std=1e-3)
11+
alloc = @allocated update_amp!(next_marginals, marginals, observations, sbm)
12+
@test alloc == 0
13+
return nothing
14+
end
15+
16+
test_allocations(CSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0))
17+
@test_skip test_allocations(
18+
GLMSBM(; N=10^3, M=10^3, c=5, λ=2, ρ=0.0, Pʷ=GaussianWeightPrior())
19+
)

test/csbm.jl

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)