Skip to content

Commit d54194a

Browse files
committed
Dict storage for messages
1 parent 102792e commit d54194a

File tree

7 files changed

+172
-154
lines changed

7 files changed

+172
-154
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
# /Manifest.toml
55
/docs/Manifest.toml
66
/docs/build/
7+
/test/playground.jl

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.9.2"
44
manifest_format = "2.0"
5-
project_hash = "fa8ae38c891ad59da40addcb7399824f2795f7b2"
5+
project_hash = "d06132272ceec5b60df1cd95ecf68b1ee04ca385"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
module StochasticBlockModelVariants
22

3-
using Base: RefValue
4-
using Graphs: AbstractGraph, has_edge, neighbors
3+
using Graphs: AbstractGraph, neighbors
54
using LinearAlgebra: Symmetric, dot
65
using ProgressMeter: @showprogress
76
using Random: AbstractRNG, default_rng
8-
using SimpleWeightedGraphs: SimpleWeightedGraph, SimpleWeightedDiGraph
7+
using SimpleWeightedGraphs: SimpleWeightedGraph
98
using Statistics: mean
109
using SparseArrays: SparseMatrixCSC, sparse, findnz
1110

@@ -14,6 +13,7 @@ export affinities, effective_snr
1413
export init_amp, update_amp!, run_amp, evaluate_amp
1514

1615
include("utils.jl")
17-
include("contextual_sbm.jl")
16+
include("csbm.jl")
17+
include("csbm_inference.jl")
1818

1919
end

src/csbm.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
## Model
2+
3+
"""
4+
ContextualSBM
5+
6+
A generative model for graphs with node features, which combines the Stochastic Block Model with a mixture of Gaussians.
7+
8+
Reference: <https://arxiv.org/abs/2306.07948>
9+
10+
# Fields
11+
12+
- `N`: graph size
13+
- `P`: feature dimension
14+
- `d`: average degree
15+
- `λ`: SNR of the communities
16+
- `μ`: SNR of the features
17+
- `ρ`: fraction of nodes revealed
18+
"""
19+
struct ContextualSBM{R<:Real}
20+
N::Int
21+
P::Int
22+
d::R
23+
λ::R
24+
μ::R
25+
ρ::R
26+
27+
function ContextualSBM(;
28+
N::Integer, P::Integer, d::R1, λ::R2, μ::R3, ρ::R4
29+
) where {R1,R2,R3,R4}
30+
R = promote_type(R1, R2, R3, R4)
31+
return new{R}(N, P, d, λ, μ, ρ)
32+
end
33+
end
34+
35+
"""
36+
effective_snr(csbm)
37+
38+
Compute the effective SNR `λ² + μ² / (N/P)`.
39+
"""
40+
function effective_snr(csbm::ContextualSBM)
41+
(; λ, μ, N, P) = csbm
42+
return abs2(λ) + abs2(μ) / (N / P)
43+
end
44+
45+
"""
46+
affinities(csbm)
47+
48+
Return a named tuple `(; cᵢ, cₒ)` containing the affinities inside and outside of a community.
49+
"""
50+
function affinities(csbm::ContextualSBM)
51+
(; d, λ) = csbm
52+
cᵢ = d + λ * sqrt(d)
53+
cₒ = d - λ * sqrt(d)
54+
return (; cᵢ, cₒ)
55+
end
56+
57+
## Latents
58+
59+
"""
60+
ContextualSBMLatents
61+
62+
The hidden variables generated by sampling from a [`ContextualSBM`](@ref).
63+
64+
# Fields
65+
66+
- `u::Vector`: community assignments, length `N`
67+
- `v::Vector`: feature centroids, length `P`
68+
"""
69+
@kwdef struct ContextualSBMLatents{R<:Real}
70+
u::Vector{Int}
71+
v::Vector{R}
72+
end
73+
74+
## Observations
75+
76+
"""
77+
ContextualSBMObservations
78+
79+
The observations generated by sampling from a [`ContextualSBM`](@ref).
80+
81+
# Fields
82+
83+
- `g::AbstractGraph`: undirected, unweighted graph generated from the adjacency matrix
84+
- `B::Matrix`: feature matrix, size `(P, N)`
85+
"""
86+
@kwdef struct ContextualSBMObservations{R<:Real,G<:AbstractGraph{Int}}
87+
B::Matrix{R}
88+
g::G
89+
end
90+
91+
## Simulation
92+
93+
"""
94+
rand(rng, csbm)
95+
96+
Sample from a [`ContextualSBM`](@ref) and return a named tuple `(; latents, observations)`.
97+
"""
98+
function Base.rand(rng::AbstractRNG, csbm::ContextualSBM)
99+
(; μ, N, P) = csbm
100+
(; cᵢ, cₒ) = affinities(csbm)
101+
102+
u = rand(rng, (-1, +1), N)
103+
v = randn(rng, P)
104+
latents = ContextualSBMLatents(; u, v)
105+
106+
Is, Js = Int[], Int[]
107+
for i in 1:N, j in (i + 1):N
108+
r = rand(rng)
109+
if (
110+
((u[i] == u[j]) && (r < cᵢ / N)) || # same community
111+
((u[i] != u[j]) && (r < cₒ / N)) # different community
112+
)
113+
push!(Is, i)
114+
push!(Is, j)
115+
push!(Js, j)
116+
push!(Js, i)
117+
end
118+
end
119+
Vs = fill(true, length(Is))
120+
A = sparse(Is, Js, Vs, N, N)
121+
g = SimpleWeightedGraph(A)
122+
123+
B = randn(rng, P, N)
124+
for i in 1:N, α in 1:P
125+
B[α, i] += sqrt/ N) * v[α] * u[i]
126+
end
127+
128+
observations = ContextualSBMObservations(; g, B)
129+
return (; latents, observations)
130+
end

0 commit comments

Comments
 (0)