Skip to content

Commit aa9bbd3

Browse files
committed
Start inference GLMSBM
1 parent 6276c79 commit aa9bbd3

File tree

12 files changed

+312
-181
lines changed

12 files changed

+312
-181
lines changed

Manifest.toml

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.9.2"
3+
julia_version = "1.10.0-beta2"
44
manifest_format = "2.0"
5-
project_hash = "dd3777d035d7f906717e2a66a2e6f5c3854a6b1a"
5+
project_hash = "4013fd98503838d81ad2bc81b622895c0637888e"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -33,7 +33,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
3333
[[deps.CompilerSupportLibraries_jll]]
3434
deps = ["Artifacts", "Libdl"]
3535
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
36-
version = "1.0.5+0"
36+
version = "1.0.5+1"
3737

3838
[[deps.DataStructures]]
3939
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -45,6 +45,12 @@ version = "0.18.14"
4545
deps = ["Printf"]
4646
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
4747

48+
[[deps.DensityInterface]]
49+
deps = ["InverseFunctions", "Test"]
50+
git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b"
51+
uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
52+
version = "0.4.0"
53+
4854
[[deps.Distributed]]
4955
deps = ["Random", "Serialization", "Sockets"]
5056
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -72,15 +78,21 @@ version = "0.1.3"
7278
deps = ["Markdown"]
7379
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7480

81+
[[deps.InverseFunctions]]
82+
deps = ["Test"]
83+
git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46"
84+
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
85+
version = "0.1.12"
86+
7587
[[deps.LibCURL]]
7688
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
7789
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
78-
version = "0.6.3"
90+
version = "0.6.4"
7991

8092
[[deps.LibCURL_jll]]
8193
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
8294
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
83-
version = "7.84.0+0"
95+
version = "8.0.1+1"
8496

8597
[[deps.LibGit2]]
8698
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
@@ -89,7 +101,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
89101
[[deps.LibSSH2_jll]]
90102
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
91103
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
92-
version = "1.10.2+0"
104+
version = "1.11.0+1"
93105

94106
[[deps.Libdl]]
95107
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -114,14 +126,14 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
114126
[[deps.MbedTLS_jll]]
115127
deps = ["Artifacts", "Libdl"]
116128
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
117-
version = "2.28.2+0"
129+
version = "2.28.2+1"
118130

119131
[[deps.Mmap]]
120132
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
121133

122134
[[deps.MozillaCACerts_jll]]
123135
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
124-
version = "2022.10.11"
136+
version = "2023.1.10"
125137

126138
[[deps.NetworkOptions]]
127139
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -130,7 +142,7 @@ version = "1.2.0"
130142
[[deps.OpenBLAS_jll]]
131143
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
132144
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
133-
version = "0.3.21+4"
145+
version = "0.3.23+2"
134146

135147
[[deps.OrderedCollections]]
136148
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
@@ -140,7 +152,7 @@ version = "1.6.0"
140152
[[deps.Pkg]]
141153
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
142154
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
143-
version = "1.9.2"
155+
version = "1.10.0"
144156

145157
[[deps.PrecompileTools]]
146158
deps = ["Preferences"]
@@ -169,7 +181,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
169181
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
170182

171183
[[deps.Random]]
172-
deps = ["SHA", "Serialization"]
184+
deps = ["SHA"]
173185
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
174186

175187
[[deps.SHA]]
@@ -201,6 +213,7 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
201213
[[deps.SparseArrays]]
202214
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
203215
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
216+
version = "1.10.0"
204217

205218
[[deps.StaticArrays]]
206219
deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
@@ -225,7 +238,7 @@ version = "1.9.0"
225238
[[deps.SuiteSparse_jll]]
226239
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
227240
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
228-
version = "5.10.1+6"
241+
version = "7.2.0+1"
229242

230243
[[deps.TOML]]
231244
deps = ["Dates"]
@@ -237,6 +250,10 @@ deps = ["ArgTools", "SHA"]
237250
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
238251
version = "1.10.0"
239252

253+
[[deps.Test]]
254+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
255+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
256+
240257
[[deps.UUIDs]]
241258
deps = ["Random", "SHA"]
242259
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
@@ -247,19 +264,19 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
247264
[[deps.Zlib_jll]]
248265
deps = ["Libdl"]
249266
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
250-
version = "1.2.13+0"
267+
version = "1.2.13+1"
251268

252269
[[deps.libblastrampoline_jll]]
253270
deps = ["Artifacts", "Libdl"]
254271
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
255-
version = "5.8.0+0"
272+
version = "5.8.0+1"
256273

257274
[[deps.nghttp2_jll]]
258275
deps = ["Artifacts", "Libdl"]
259276
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
260-
version = "1.48.0+0"
277+
version = "1.52.0+1"
261278

262279
[[deps.p7zip_jll]]
263280
deps = ["Artifacts", "Libdl"]
264281
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
265-
version = "17.4.0+0"
282+
version = "17.4.0+2"

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Guillaume Dalle <22795598+gdalle@users.noreply.github.com> and contr
44
version = "0.1.0"
55

66
[deps]
7+
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
78
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"

docs/plots.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function compute_fig1(; N=3 * 10^3, P=N ÷ 10, d=5, μ=2, ρ=0.1, λ_values=0:0.
1616
@threads for i in 1:trials
1717
for j in eachindex(λ_values)
1818
λ = λ_values[j]
19-
csbm = ContextualSBM(; N, P, d, μ, λ, ρ)
19+
csbm = CSBM(; N, P, d, μ, λ, ρ)
2020
(; qᵤ, qᵥ) = evaluate_amp(rng; csbm)
2121
qᵤ_values[i, j] = qᵤ
2222
qᵥ_values[i, j] = qᵥ
Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module StochasticBlockModelVariants
22

3+
using DensityInterface: logdensityof, densityof
34
using Graphs: AbstractGraph, neighbors
45
using LinearAlgebra: dot, mul!, norm
56
using PrecompileTools: @compile_workload
@@ -9,21 +10,23 @@ using SimpleWeightedGraphs: SimpleWeightedGraph
910
using Statistics: mean, std
1011
using SparseArrays: SparseMatrixCSC, sparse, findnz
1112

12-
export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations
13-
export affinities, effective_snr
13+
export nb_features, average_degree, communities_snr, affinities, features_snr, effective_snr
14+
export CSBM, LatentsCSBM, ObservationsCSBM
15+
export GLMSBM, LatentsGLMSBM, ObservationsGLMSBM
1416
export overlaps
1517
export init_amp, update_amp!, run_amp, evaluate_amp
1618

1719
include("abstract_sbm.jl")
1820
include("utils.jl")
1921
include("csbm.jl")
22+
include("glmsbm.jl")
2023
include("csbm_inference.jl")
2124

22-
@compile_workload begin
23-
rng = default_rng()
24-
csbm = ContextualSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.1)
25-
(; observations) = rand(rng, csbm)
26-
run_amp(rng; observations, csbm, max_iterations=2)
27-
end
25+
# @compile_workload begin
26+
# rng = default_rng()
27+
# csbm = CSBM(; N=10^2, P=10^2, d=5, λ=2, μ=2, ρ=0.1)
28+
# (; observations) = rand(rng, csbm)
29+
# run_amp(rng; observations, csbm, max_iterations=2)
30+
# end
2831

2932
end

src/abstract_sbm.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ Abstract supertype for Stochastic Block Models with additional node features.
55
"""
66
abstract type AbstractSBM end
77

8+
"""
9+
length(sbm)
10+
11+
Return the number of nodes `N` in the graph.
12+
"""
13+
Base.length
14+
15+
"""
16+
nb_features(sbm)
17+
18+
Return the number of features for each node in the graph.
19+
"""
20+
function nb_features end
21+
822
"""
923
average_degree(sbm)
1024
@@ -31,3 +45,58 @@ function affinities(sbm::AbstractSBM)
3145
cₒ = d - λ * sqrt(d)
3246
return (; cᵢ, cₒ)
3347
end
48+
49+
"""
50+
fraction_observed(sbm)
51+
52+
Return the fraction `ρ` of community assignments that are observed.
53+
"""
54+
function fraction_observed end
55+
56+
"""
57+
sample_graph(rng, sbm, communities)
58+
59+
Sample a graph `g` from an SBM based on known community assignments.
60+
"""
61+
function sample_graph(rng::AbstractRNG, sbm::AbstractSBM, communities::Vector{<:Integer})
62+
N = length(sbm)
63+
(; cᵢ, cₒ) = affinities(sbm)
64+
Is, Js = Int[], Int[]
65+
for i in 1:N, j in (i + 1):N
66+
r = rand(rng)
67+
if (
68+
((communities[i] == communities[j]) && (r < cᵢ / N)) ||
69+
((communities[i] != communities[j]) && (r < cₒ / N))
70+
)
71+
push!(Is, i)
72+
push!(Is, j)
73+
push!(Js, j)
74+
push!(Js, i)
75+
end
76+
end
77+
Vs = fill(true, length(Is))
78+
A = sparse(Is, Js, Vs, N, N)
79+
g = SimpleWeightedGraph(A)
80+
return g
81+
end
82+
83+
"""
84+
sample_mask(rng, sbm, communities)
85+
86+
Sample a vector `Ξ` (Xi) whose components are equal to the community assignments with probability `ρ` and equal to `missing` with probability `1-ρ`.
87+
"""
88+
function sample_mask(rng::AbstractRNG, sbm::AbstractSBM, communities::Vector{<:Integer})
89+
N = length(sbm)
90+
ρ = fraction_observed(sbm)
91+
Ξ = Vector{Union{Missing, Int}}(undef, N)
92+
Ξ .= missing
93+
for i in 1:N
94+
if rand(rng) < ρ
95+
Ξ[i] = communities[i]
96+
end
97+
end
98+
return Ξ
99+
end
100+
101+
prior₊(::Type{R}, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == 1)
102+
prior₋(::Type{R}, Ξᵢ) where {R} = ismissing(Ξᵢ) ? one(R) / 2 : R(Ξᵢ == -1)

0 commit comments

Comments
 (0)