Skip to content

Commit 55cfbfd

Browse files
committed
Sample CSBM
1 parent 27a2f20 commit 55cfbfd

File tree

6 files changed

+109
-5
lines changed

6 files changed

+109
-5
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"julia.environmentPath": "/home/gdalle/Documents/GitHub/Julia/StochasticBlockModelVariants.jl"
3+
}

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ authors = ["Guillaume Dalle <22795598+gdalle@users.noreply.github.com> and contr
44
version = "0.1.0"
55

66
[deps]
7-
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
87
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
910
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011

1112
[compat]
12-
Graphs = "1.8"
13+
SimpleWeightedGraphs = "1.4"
1314
julia = "1.9"
1415

1516
[extras]
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
module StochasticBlockModelVariants
22

3-
using Graphs
4-
using LinearAlgebra
5-
using SparseArrays
3+
using SimpleWeightedGraphs: SimpleWeightedGraph
4+
using LinearAlgebra: Symmetric
5+
using Random: AbstractRNG, default_rng
6+
using SparseArrays: SparseMatrixCSC, sparse
7+
8+
export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations
9+
10+
include("contextual_sbm.jl")
611

712
end

src/contextual_sbm.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
struct ContextualSBM{R<:Real}
2+
d::R
3+
λ::R
4+
μ::R
5+
N::Int
6+
P::Int
7+
8+
function ContextualSBM(; d::R1, λ::R2, μ::R3, N, P) where {R1,R2,R3}
9+
R = promote_type(R1, R2, R3)
10+
return new{R}(d, λ, μ, N, P)
11+
end
12+
end
13+
14+
@kwdef struct ContextualSBMLatents{R<:Real}
15+
u::Vector{Int} # (N,)
16+
v::Vector{R} # (P,)
17+
end
18+
19+
@kwdef struct ContextualSBMObservations{R<:Real}
20+
A::Symmetric{Bool,SparseMatrixCSC{Bool,Int}} # (N, N)
21+
G::SimpleWeightedGraph{Int,Bool}
22+
B::Matrix{R} # (P, N)
23+
end
24+
25+
@kwdef struct ContextualSBMMessages
26+
# From variables to factors
27+
χ_node_node
28+
χ_node_feat
29+
χ_feat_node
30+
# From factors to variables
31+
ψ_node_node
32+
ψ_node_feat
33+
ψ_feat_node
34+
end
35+
36+
const CSBM = ContextualSBM
37+
const CSBML = ContextualSBMLatents
38+
const CSBMO = ContextualSBMObservations
39+
40+
function affinities(csbm::CSBM)
41+
(; d, λ) = csbm
42+
cᵢ = d + λ * sqrt(d)
43+
cₒ = d - λ * sqrt(d)
44+
return (; cᵢ, cₒ)
45+
end
46+
47+
nb_nodes(csbm::CSBM) = csbm.N
48+
nb_nodes(latents::CSBML) = length(latents.u)
49+
nb_nodes(obs::CSBMO) = size(obs.A, 1)
50+
51+
nb_features(csbm::CSBM) = csbm.P
52+
nb_features(latents::CSBML) = length(latents.v)
53+
nb_features(obs::CSBMO) = size(obs.B, 1)
54+
55+
function Base.rand(rng::AbstractRNG, csbm::CSBM)
56+
N, P = nb_nodes(csbm), nb_features(csbm)
57+
μ = csbm.μ
58+
(; cᵢ, cₒ) = affinities(csbm)
59+
60+
u = rand(rng, (-1, +1), N)
61+
v = randn(rng, P)
62+
63+
Is, Js = Int[], Int[]
64+
for i in 1:N, j in 1:i
65+
r = rand(rng)
66+
if (u[i] == u[j] && r < cᵢ / N) || (u[i] != u[j] && r < cₒ / N)
67+
push!(Is, i)
68+
push!(Js, j)
69+
end
70+
end
71+
Vs = fill(true, length(Is))
72+
A = Symmetric(sparse(Is, Js, Vs, N, N))
73+
G = SimpleWeightedGraph(A)
74+
75+
Z = randn(rng, P, N)
76+
B = similar(Z)
77+
for α in 1:P, i in 1:N
78+
B[α, i] = sqrt/ N) * v[α] * u[i] * Z[α, i]
79+
end
80+
81+
latents = CSBML(; u, v)
82+
obs = CSBMO(; A, G, B)
83+
return (; latents, obs)
84+
end
85+
86+
Base.rand(csbm::CSBM) = Base.rand(default_rng(), csbm)

test/contextual_sbm.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using StochasticBlockModelVariants
2+
3+
csbm = ContextualSBM(; d=3, λ=1, μ=2.0, N=10, P=20)
4+
5+
(; latents, obs) = rand(csbm)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,8 @@ using Test
2727
@testset "Doctests" begin
2828
doctest(StochasticBlockModelVariants)
2929
end
30+
31+
@testset verbose = true "Contextual SBM" begin
32+
include("contextual_sbm.jl")
33+
end
3034
end

0 commit comments

Comments
 (0)