Skip to content

Commit 29a33a8

Browse files
committed
Basic structs
1 parent bc8574b commit 29a33a8

File tree

6 files changed

+112
-15
lines changed

6 files changed

+112
-15
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
# /Manifest.toml
55
/docs/Manifest.toml
66
/docs/build/
7-
/test/playground.jl
7+
playground.jl

CITATION.bib

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@misc{StochasticBlockModelVariants.jl,
2-
author = {Guillaume Dalle <22795598+gdalle@users.noreply.github.com> and contributors},
2+
author = {Guillaume Dalle},
33
title = {StochasticBlockModelVariants.jl},
44
url = {https://github.com/gdalle/StochasticBlockModelVariants.jl},
55
version = {v0.1.0},

src/StochasticBlockModelVariants.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export affinities, effective_snr
1414
export overlaps
1515
export init_amp, update_amp!, run_amp, evaluate_amp
1616

17+
include("abstract_sbm.jl")
1718
include("utils.jl")
1819
include("csbm.jl")
1920
include("csbm_inference.jl")

src/abstract_sbm.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
AbstractSBM
3+
4+
Abstract supertype for Stochastic Block Models with additional node features.
5+
"""
6+
abstract type AbstractSBM end
7+
8+
"""
9+
average_degree(sbm)
10+
11+
Return the average degre `d` of a node in the graph.
12+
"""
13+
function average_degree end
14+
15+
"""
16+
communities_snr(sbm)
17+
18+
Return the signal-to-noise ratio `λ` of the communities in the graph.
19+
"""
20+
function communities_snr end
21+
22+
"""
23+
affinities(sbm)
24+
25+
Return a named tuple `(; cᵢ, cₒ)` containing the affinities inside and outside of a community.
26+
"""
27+
function affinities(sbm::AbstractSBM)
28+
d = average_degree(sbm)
29+
λ = communities_snr(sbm)
30+
cᵢ = d + λ * sqrt(d)
31+
cₒ = d - λ * sqrt(d)
32+
return (; cᵢ, cₒ)
33+
end

src/csbm.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
ContextualSBM
55
6-
A generative model for graphs with node features, which combines the Stochastic Block Model with a mixture of Gaussians.
6+
A generative model for graphs with node features, which combines a Stochastic Block Model with a mixture of Gaussians.
77
88
Reference: <https://arxiv.org/abs/2306.07948>
99
@@ -16,7 +16,7 @@ Reference: <https://arxiv.org/abs/2306.07948>
1616
- `μ`: SNR of the features
1717
- `ρ`: fraction of nodes revealed
1818
"""
19-
struct ContextualSBM{R<:Real}
19+
struct ContextualSBM{R<:Real} <: AbstractSBM
2020
N::Int
2121
P::Int
2222
d::R
@@ -32,6 +32,9 @@ struct ContextualSBM{R<:Real}
3232
end
3333
end
3434

35+
average_degree(csbm::CSBM) = csbm.d
36+
communities_snr(csbm::CSBM) = csbm.λ
37+
3538
"""
3639
effective_snr(csbm)
3740
@@ -42,17 +45,6 @@ function effective_snr(csbm::ContextualSBM)
4245
return abs2(λ) + abs2(μ) / (N / P)
4346
end
4447

45-
"""
46-
affinities(csbm)
47-
48-
Return a named tuple `(; cᵢ, cₒ)` containing the affinities inside and outside of a community.
49-
"""
50-
function affinities(csbm::ContextualSBM)
51-
(; d, λ) = csbm
52-
cᵢ = d + λ * sqrt(d)
53-
cₒ = d - λ * sqrt(d)
54-
return (; cᵢ, cₒ)
55-
end
5648

5749
## Latents
5850

src/glmsbm.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
## Model
2+
3+
"""
4+
GLMSBM
5+
6+
A generative model for graphs with node features, which combines a Generalized Linear Model with a Stochastic Block Model.
7+
8+
Reference: <https://arxiv.org/abs/2303.09995>
9+
10+
# Fields
11+
12+
- `N`: graph size
13+
- `M`: feature dimension
14+
- `c`: average degree
15+
- `λ`: SNR of the communities
16+
- `ρ`: fraction of nodes revealed
17+
"""
18+
struct GLMSBM{R<:Real} <: AbstractSBM
19+
N::Int
20+
M::Int
21+
c::R
22+
λ::R
23+
ρ::R
24+
25+
function GLMSBM(; N::Integer, M::Integer, c::R1, λ::R2, ρ::R3) where {R1,R2,R3}
26+
R = promote_type(R1, R2, R3)
27+
return new{R}(N, M, c, λ, ρ)
28+
end
29+
end
30+
31+
average_degree(glmsbm::GLMSBM) = glmsbm.c
32+
communities_snr(glmsbm::GLMSBM) = glmsbm.λ
33+
34+
## Latents
35+
36+
"""
37+
GLMSBMLatents
38+
39+
The hidden variables generated by sampling from a [`GLMSBM`](@ref).
40+
41+
# Fields
42+
43+
- `w::Vector`: feature weights, length `M`
44+
- `s::Vector`: community assignments, length `N`
45+
"""
46+
@kwdef struct GLMSBMLatents{R<:Real}
47+
w::Vector{R}
48+
s::Vector{Int}
49+
end
50+
51+
## Observations
52+
53+
"""
54+
GLMSBMObservations
55+
56+
The observations generated by sampling from a [`GLMSBM`](@ref).
57+
58+
# Fields
59+
- `A::AbstractMatrix`: symmetric boolean adjacency matrix, size `(N, N)`
60+
- `g::AbstractGraph`: undirected unweighted graph generated from `A`
61+
- `F::Matrix`: feature matrix, size `(M, N)`
62+
- `Ξ::Vector`: revealed communities `±1` for a fraction of nodes and `0` for the rest, length `N`
63+
"""
64+
@kwdef struct ContextualSBMObservations{
65+
R<:Real,M<:AbstractMatrix{Bool},G<:AbstractGraph{Int}
66+
}
67+
A::M
68+
g::G
69+
F::Matrix{R}
70+
Ξ::Vector{Int}
71+
end

0 commit comments

Comments
 (0)