Skip to content

Commit aeb63f9

Browse files
committed
Finish recoding GLMSBM
1 parent 5b2fc68 commit aeb63f9

File tree

9 files changed

+252
-194
lines changed

9 files changed

+252
-194
lines changed

Manifest.toml

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0-beta2"
44
manifest_format = "2.0"
5-
project_hash = "5d6b46b606378213b4547f42572ed5ae33cb9d37"
5+
project_hash = "ba25bdf8dd9a1a3a609260b65801f3325dbdaa20"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -22,9 +22,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2222

2323
[[deps.Compat]]
2424
deps = ["UUIDs"]
25-
git-tree-sha1 = "4e88377ae7ebeaf29a047aa1ee40826e0b708a5d"
25+
git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7"
2626
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
27-
version = "4.7.0"
27+
version = "4.9.0"
2828
weakdeps = ["Dates", "LinearAlgebra"]
2929

3030
[deps.Compat.extensions]
@@ -37,9 +37,9 @@ version = "1.0.5+1"
3737

3838
[[deps.DataStructures]]
3939
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
40-
git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4"
40+
git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
4141
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
42-
version = "0.18.14"
42+
version = "0.18.15"
4343

4444
[[deps.Dates]]
4545
deps = ["Printf"]
@@ -75,6 +75,12 @@ git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4"
7575
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
7676
version = "1.8.0"
7777

78+
[[deps.Infiltrator]]
79+
deps = ["InteractiveUtils", "Markdown", "REPL", "UUIDs"]
80+
git-tree-sha1 = "04de041e4590428cccbb026d86e5d670513be2e3"
81+
uuid = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
82+
version = "1.6.4"
83+
7884
[[deps.Inflate]]
7985
git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
8086
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
@@ -143,14 +149,20 @@ version = "0.3.26"
143149
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
144150
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
145151

152+
[[deps.LogarithmicNumbers]]
153+
deps = ["Random"]
154+
git-tree-sha1 = "8522befb54ff3b4bcf17d57b14b884d536a22015"
155+
uuid = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
156+
version = "1.2.1"
157+
146158
[[deps.Logging]]
147159
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
148160

149161
[[deps.MacroTools]]
150162
deps = ["Markdown", "Random"]
151-
git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2"
163+
git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
152164
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
153-
version = "0.5.10"
165+
version = "0.5.11"
154166

155167
[[deps.Markdown]]
156168
deps = ["Base64"]
@@ -189,9 +201,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
189201
version = "0.5.5+0"
190202

191203
[[deps.OrderedCollections]]
192-
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
204+
git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
193205
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
194-
version = "1.6.0"
206+
version = "1.6.2"
195207

196208
[[deps.Pkg]]
197209
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
@@ -200,9 +212,9 @@ version = "1.10.0"
200212

201213
[[deps.PrecompileTools]]
202214
deps = ["Preferences"]
203-
git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81"
215+
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
204216
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
205-
version = "1.1.2"
217+
version = "1.2.0"
206218

207219
[[deps.Preferences]]
208220
deps = ["TOML"]
@@ -216,9 +228,9 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
216228

217229
[[deps.ProgressMeter]]
218230
deps = ["Distributed", "Printf"]
219-
git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9"
231+
git-tree-sha1 = "ae36206463b2395804f2787ffe172f44452b538d"
220232
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
221-
version = "1.7.2"
233+
version = "1.8.0"
222234

223235
[[deps.REPL]]
224236
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
@@ -282,9 +294,9 @@ weakdeps = ["Statistics"]
282294
StaticArraysStatisticsExt = "Statistics"
283295

284296
[[deps.StaticArraysCore]]
285-
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
297+
git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d"
286298
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
287-
version = "1.4.0"
299+
version = "1.4.2"
288300

289301
[[deps.Statistics]]
290302
deps = ["LinearAlgebra", "SparseArrays"]

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ version = "0.1.0"
55

66
[deps]
77
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
8+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
89
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
10+
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
911
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
1013
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1114
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1215
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/StochasticBlockModelVariants.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
"""
2+
StochasticBlockModelVariants
3+
4+
A package for inference in SBMs with node features using message-passing algorithms.
5+
6+
# Exports
7+
8+
$(EXPORTS)
9+
"""
110
module StochasticBlockModelVariants
211

312
using DensityInterface: logdensityof, densityof
13+
using DocStringExtensions
414
using Graphs: AbstractGraph, neighbors
15+
using Infiltrator
516
using LinearAlgebra: LinearAlgebra, dot, mul!, norm, normalize!
617
using PrecompileTools: @compile_workload
718
using ProgressMeter: Progress, next!
@@ -18,7 +29,6 @@ export overlaps
1829
export init_amp, update_amp!, run_amp, evaluate_amp
1930

2031
include("abstract_sbm.jl")
21-
include("utils.jl")
2232
include("csbm.jl")
2333
include("glmsbm.jl")
2434
include("csbm_inference.jl")

src/abstract_sbm.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
11
"""
2-
AbstractSBM
2+
$(TYPEDEF)
33
44
Abstract supertype for Stochastic Block Models with additional node features.
55
"""
66
abstract type AbstractSBM end
77

88
"""
9-
length(sbm)
9+
length(sbm::AbstractSBM)
1010
1111
Return the number of nodes `N` in the graph.
1212
"""
1313
Base.length
1414

1515
"""
16-
nb_features(sbm)
16+
nb_features(sbm::AbstractSBM)
1717
18-
Return the number of features for each node in the graph.
18+
Return the number of nodes `N` in the graph.
1919
"""
2020
function nb_features end
2121

2222
"""
23-
average_degree(sbm)
23+
average_degree(sbm::AbstractSBM)
2424
2525
Return the average degre `d` of a node in the graph.
2626
"""
2727
function average_degree end
2828

2929
"""
30-
communities_snr(sbm)
30+
communities_snr(sbm::AbstractSBM)
3131
3232
Return the signal-to-noise ratio `λ` of the communities in the graph.
3333
"""
@@ -46,19 +46,6 @@ function affinities(sbm::AbstractSBM)
4646
return (; cᵢ, cₒ)
4747
end
4848

49-
struct AffinityMatrix{R}
50-
cᵢ::R
51-
cₒ::R
52-
end
53-
54-
function Base.getindex(C::AffinityMatrix, i, j)
55-
if i == j
56-
return C.cᵢ
57-
else
58-
return C.cₒ
59-
end
60-
end
61-
6249
"""
6350
fraction_observed(sbm)
6451
@@ -112,3 +99,19 @@ function sample_mask(rng::AbstractRNG, sbm::AbstractSBM, communities::Vector{<:I
11299
end
113100

114101
prior(::Type{R}, u, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == u)
102+
103+
sigmoid(x) = 1 / (1 + exp(-x))
104+
105+
freq_equalities(x, y) = mean(x[i] y[i] for i in eachindex(x, y))
106+
107+
function discrete_overlap(u, û)
108+
q̂ᵤ = max(freq_equalities(û, u), freq_equalities(û, -u))
109+
qᵤ = 2 * (q̂ᵤ - one(R) / 2)
110+
return qᵤ
111+
end
112+
113+
function continuous_overlap(v, v̂)
114+
q̂ᵥ = max(abs(dot(v̂, v)), abs(dot(v̂, -v)))
115+
qᵥ = q̂ᵥ / (eps() + norm(v̂) * norm(v))
116+
return qᵥ
117+
end

src/csbm.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
## Model
22

33
"""
4-
CSBM
4+
$(TYPEDEF)
55
66
A generative model for graphs with node features, which combines a Stochastic Block Model with a mixture of Gaussians.
77
88
Reference: <https://arxiv.org/abs/2306.07948>
99
1010
# Fields
1111
12-
- `N`: graph size
13-
- `P`: feature dimension
14-
- `d`: average degree
15-
- `λ`: SNR of the communities
16-
- `μ`: SNR of the features
17-
- `ρ`: fraction of node assignments observed
12+
$(TYPEDFIELDS)
1813
"""
19-
struct CSBM{R<:Real} <: AbstractSBM
14+
struct CSBM{R} <: AbstractSBM
15+
"graph size"
2016
N::Int
17+
"feature dimension"
2118
P::Int
19+
"average degree"
2220
d::R
21+
"SNR of the communities"
2322
λ::R
23+
"SNR of the features"
2424
μ::R
25+
"fraction of node assignments observed"
2526
ρ::R
2627

2728
function CSBM(; N::Integer, P::Integer, d::R1, λ::R2, μ::R3, ρ::R4) where {R1,R2,R3,R4}
@@ -50,35 +51,38 @@ end
5051
## Latents
5152

5253
"""
53-
LatentsCSBM
54+
$(TYPEDEF)
5455
5556
The hidden variables generated by sampling from a [`CSBM`](@ref).
5657
5758
# Fields
5859
59-
- `u::Vector`: community assignments, length `N`
60-
- `v::Vector`: feature centroids, length `P`
60+
$(TYPEDFIELDS)
6161
"""
6262
@kwdef struct LatentsCSBM{R<:Real}
63+
"community assignments, length `N`"
6364
u::Vector{Int}
65+
"feature centroids, length `P`"
6466
v::Vector{R}
6567
end
6668

6769
## Observations
6870

6971
"""
70-
ObservationsCSBM
72+
$(TYPEDEF)
7173
7274
The observations generated by sampling from a [`CSBM`](@ref).
7375
7476
# Fields
75-
- `g::AbstractGraph`: undirected unweighted graph with `N` nodes (~ adjacency matrix `A`)
76-
- `Ξ::Vector`: revealed communities `±1` for a fraction `ρ` of nodes and `missing` for the rest, length `N`
77-
- `B::Matrix`: feature matrix, size `(P, N)`
77+
78+
$(TYPEDFIELDS)
7879
"""
7980
@kwdef struct ObservationsCSBM{R<:Real,G<:AbstractGraph{Int}}
81+
"undirected unweighted graph with `N` nodes (~ adjacency matrix `A`)"
8082
g::G
83+
"revealed communities `±1` for a fraction `ρ` of nodes and `missing` for the rest, length `N`"
8184
Ξ::Vector{Union{Int,Missing}}
85+
"feature matrix, size `(P, N)`"
8286
B::Matrix{R}
8387
end
8488

0 commit comments

Comments
 (0)