|
| 1 | +struct ContextualSBM{R<:Real} |
| 2 | + d::R |
| 3 | + λ::R |
| 4 | + μ::R |
| 5 | + N::Int |
| 6 | + P::Int |
| 7 | + |
| 8 | + function ContextualSBM(; d::R1, λ::R2, μ::R3, N, P) where {R1,R2,R3} |
| 9 | + R = promote_type(R1, R2, R3) |
| 10 | + return new{R}(d, λ, μ, N, P) |
| 11 | + end |
| 12 | +end |
| 13 | + |
| 14 | +@kwdef struct ContextualSBMLatents{R<:Real} |
| 15 | + u::Vector{Int} # (N,) |
| 16 | + v::Vector{R} # (P,) |
| 17 | +end |
| 18 | + |
| 19 | +@kwdef struct ContextualSBMObservations{R<:Real} |
| 20 | + A::Symmetric{Bool,SparseMatrixCSC{Bool,Int}} # (N, N) |
| 21 | + G::SimpleWeightedGraph{Int,Bool} |
| 22 | + B::Matrix{R} # (P, N) |
| 23 | +end |
| 24 | + |
| 25 | +@kwdef struct ContextualSBMMessages |
| 26 | + # From variables to factors |
| 27 | + χ_node_node |
| 28 | + χ_node_feat |
| 29 | + χ_feat_node |
| 30 | + # From factors to variables |
| 31 | + ψ_node_node |
| 32 | + ψ_node_feat |
| 33 | + ψ_feat_node |
| 34 | +end |
| 35 | + |
| 36 | +const CSBM = ContextualSBM |
| 37 | +const CSBML = ContextualSBMLatents |
| 38 | +const CSBMO = ContextualSBMObservations |
| 39 | + |
| 40 | +function affinities(csbm::CSBM) |
| 41 | + (; d, λ) = csbm |
| 42 | + cᵢ = d + λ * sqrt(d) |
| 43 | + cₒ = d - λ * sqrt(d) |
| 44 | + return (; cᵢ, cₒ) |
| 45 | +end |
| 46 | + |
| 47 | +nb_nodes(csbm::CSBM) = csbm.N |
| 48 | +nb_nodes(latents::CSBML) = length(latents.u) |
| 49 | +nb_nodes(obs::CSBMO) = size(obs.A, 1) |
| 50 | + |
| 51 | +nb_features(csbm::CSBM) = csbm.P |
| 52 | +nb_features(latents::CSBML) = length(latents.v) |
| 53 | +nb_features(obs::CSBMO) = size(obs.B, 1) |
| 54 | + |
| 55 | +function Base.rand(rng::AbstractRNG, csbm::CSBM) |
| 56 | + N, P = nb_nodes(csbm), nb_features(csbm) |
| 57 | + μ = csbm.μ |
| 58 | + (; cᵢ, cₒ) = affinities(csbm) |
| 59 | + |
| 60 | + u = rand(rng, (-1, +1), N) |
| 61 | + v = randn(rng, P) |
| 62 | + |
| 63 | + Is, Js = Int[], Int[] |
| 64 | + for i in 1:N, j in 1:i |
| 65 | + r = rand(rng) |
| 66 | + if (u[i] == u[j] && r < cᵢ / N) || (u[i] != u[j] && r < cₒ / N) |
| 67 | + push!(Is, i) |
| 68 | + push!(Js, j) |
| 69 | + end |
| 70 | + end |
| 71 | + Vs = fill(true, length(Is)) |
| 72 | + A = Symmetric(sparse(Is, Js, Vs, N, N)) |
| 73 | + G = SimpleWeightedGraph(A) |
| 74 | + |
| 75 | + Z = randn(rng, P, N) |
| 76 | + B = similar(Z) |
| 77 | + for α in 1:P, i in 1:N |
| 78 | + B[α, i] = sqrt(μ / N) * v[α] * u[i] * Z[α, i] |
| 79 | + end |
| 80 | + |
| 81 | + latents = CSBML(; u, v) |
| 82 | + obs = CSBMO(; A, G, B) |
| 83 | + return (; latents, obs) |
| 84 | +end |
| 85 | + |
| 86 | +Base.rand(csbm::CSBM) = Base.rand(default_rng(), csbm) |
0 commit comments