Skip to content

Commit 716c52f

Browse files
committed
More variables
1 parent aa9bbd3 commit 716c52f

File tree

6 files changed

+236
-45
lines changed

6 files changed

+236
-45
lines changed

Manifest.toml

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0-beta2"
44
manifest_format = "2.0"
5-
project_hash = "4013fd98503838d81ad2bc81b622895c0637888e"
5+
project_hash = "3ba5719890857eaf711cde76af0aade7729abd07"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -55,6 +55,12 @@ version = "0.4.0"
5555
deps = ["Random", "Serialization", "Sockets"]
5656
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
5757

58+
[[deps.DocStringExtensions]]
59+
deps = ["LibGit2"]
60+
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
61+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
62+
version = "0.9.3"
63+
5864
[[deps.Downloads]]
5965
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
6066
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
@@ -84,6 +90,17 @@ git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46"
8490
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
8591
version = "0.1.12"
8692

93+
[[deps.IrrationalConstants]]
94+
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
95+
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
96+
version = "0.2.2"
97+
98+
[[deps.JLLWrappers]]
99+
deps = ["Artifacts", "Preferences"]
100+
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
101+
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
102+
version = "1.5.0"
103+
87104
[[deps.LibCURL]]
88105
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
89106
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
@@ -110,6 +127,22 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
110127
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
111128
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
112129

130+
[[deps.LogExpFunctions]]
131+
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
132+
git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
133+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
134+
version = "0.3.26"
135+
136+
[deps.LogExpFunctions.extensions]
137+
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
138+
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
139+
LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
140+
141+
[deps.LogExpFunctions.weakdeps]
142+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
143+
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
144+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
145+
113146
[[deps.Logging]]
114147
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
115148

@@ -144,6 +177,17 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
144177
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
145178
version = "0.3.23+2"
146179

180+
[[deps.OpenLibm_jll]]
181+
deps = ["Artifacts", "Libdl"]
182+
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
183+
version = "0.8.1+2"
184+
185+
[[deps.OpenSpecFun_jll]]
186+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
187+
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
188+
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
189+
version = "0.5.5+0"
190+
147191
[[deps.OrderedCollections]]
148192
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
149193
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -215,11 +259,23 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
215259
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
216260
version = "1.10.0"
217261

262+
[[deps.SpecialFunctions]]
263+
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
264+
git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d"
265+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
266+
version = "2.3.1"
267+
268+
[deps.SpecialFunctions.extensions]
269+
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
270+
271+
[deps.SpecialFunctions.weakdeps]
272+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
273+
218274
[[deps.StaticArrays]]
219275
deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
220-
git-tree-sha1 = "0da7e6b70d1bb40b1ace3b576da9ea2992f76318"
276+
git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881"
221277
uuid = "90137ffa-7385-5640-81b9-e52037218182"
222-
version = "1.6.0"
278+
version = "1.6.2"
223279
weakdeps = ["Statistics"]
224280

225281
[deps.StaticArrays.extensions]

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617

1718
[compat]
19+
DensityInterface = "0.4"
1820
Graphs = "1.8"
1921
PrecompileTools = "1.1"
2022
ProgressMeter = "1.7"
2123
SimpleWeightedGraphs = "1.4"
24+
SpecialFunctions = "2.3"
2225
julia = "1.9"
2326

2427
[extras]

src/StochasticBlockModelVariants.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using PrecompileTools: @compile_workload
77
using ProgressMeter: Progress, next!
88
using Random: AbstractRNG, default_rng
99
using SimpleWeightedGraphs: SimpleWeightedGraph
10+
using SpecialFunctions: erf
1011
using Statistics: mean, std
1112
using SparseArrays: SparseMatrixCSC, sparse, findnz
1213

src/csbm_inference.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Marginals
22

33
"""
4-
AMPMarginalsCSBM
4+
MarginalsCSBM
55
66
# Fields
77
@@ -11,41 +11,41 @@
1111
- `v̂_no_comm::Vector`: posterior mean of `v` if there were no communities, length `P` (aka `Bᵤ`)
1212
- `h̃₊::Vector`: individual external field for `u=1`, length `N`
1313
- `h̃₋::Vector`: individual external field for `u=-1`, length `N`
14-
- `χ₊e::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
14+
- `χe₊::Dict`: messages about the marginal distribution of `u`, size `(N, N)`
1515
- `χ₊::Vector`: marginal probability of `u=1`, length `N`
1616
"""
17-
@kwdef struct AMPMarginalsCSBM{R<:Real}
17+
@kwdef struct MarginalsCSBM{R<:Real}
1818
::Vector{R}
1919
::Vector{R}
2020
û_no_feat::Vector{R}
2121
v̂_no_comm::Vector{R}
2222
h̃₊::Vector{R}
2323
h̃₋::Vector{R}
24-
χ₊e::Dict{Tuple{Int,Int},R}
24+
χe₊::Dict{Tuple{Int,Int},R}
2525
χ₊::Vector{R}
2626
end
2727

28-
function Base.copy(marginals::AMPMarginalsCSBM)
29-
return AMPMarginalsCSBM(;
28+
function Base.copy(marginals::MarginalsCSBM)
29+
return MarginalsCSBM(;
3030
=copy(marginals.û),
3131
=copy(marginals.v̂),
3232
û_no_feat=copy(marginals.û_no_feat),
3333
v̂_no_comm=copy(marginals.v̂_no_comm),
3434
h̃₊=copy(marginals.h̃₊),
3535
h̃₋=copy(marginals.h̃₋),
36-
χ₊e=copy(marginals.χ₊e),
36+
χe₊=copy(marginals.χe₊),
3737
χ₊=copy(marginals.χ₊),
3838
)
3939
end
4040

41-
function Base.copy!(marginals_dest::AMPMarginalsCSBM, marginals_source::AMPMarginalsCSBM)
41+
function Base.copy!(marginals_dest::MarginalsCSBM, marginals_source::MarginalsCSBM)
4242
copy!(marginals_dest.û, marginals_source.û),
4343
copy!(marginals_dest.v̂, marginals_source.v̂),
4444
copy!(marginals_dest.û_no_feat, marginals_source.û_no_feat),
4545
copy!(marginals_dest.v̂_no_comm, marginals_source.v̂_no_comm),
4646
copy!(marginals_dest.h̃₊, marginals_source.h̃₊),
4747
copy!(marginals_dest.h̃₋, marginals_source.h̃₋),
48-
copy!(marginals_dest.χ₊e, marginals_source.χ₊e),
48+
copy!(marginals_dest.χe₊, marginals_source.χe₊),
4949
copy!(marginals_dest.χ₊, marginals_source.χ₊),
5050
return marginals_dest
5151
end
@@ -60,42 +60,42 @@ function init_amp(
6060

6161
= 2 .* prior₊.(R, Ξ) .- one(R) .+ init_std .* randn(rng, R, N)
6262
= init_std .* randn(rng, R, P)
63-
63+
6464
û_no_feat = zeros(R, N)
6565
v̂_no_comm = zeros(R, P)
66-
66+
6767
h̃₊ = zeros(R, N)
6868
h̃₋ = zeros(R, N)
69-
70-
χ₊e = Dict{Tuple{Int,Int},R}()
69+
70+
χe₊ = Dict{Tuple{Int,Int},R}()
7171
for i in 1:N, j in neighbors(g, i)
72-
χ₊e[i, j] = prior₊(R, Ξ[i]) + init_std * randn(rng, R)
72+
χe₊[i, j] = prior₊(R, Ξ[i]) + init_std * randn(rng, R)
7373
end
7474
χ₊ = zeros(R, N)
75-
76-
marginals = AMPMarginalsCSBM(; û, v̂, û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊e, χ₊)
75+
76+
marginals = MarginalsCSBM(; û, v̂, û_no_feat, v̂_no_comm, h̃₊, h̃₋, χe₊, χ₊)
7777
next_marginals = copy(marginals)
7878
return (; marginals, next_marginals)
7979
end
8080

8181
function update_amp!(
82-
next_marginals::AMPMarginalsCSBM{R};
83-
marginals::AMPMarginalsCSBM{R},
82+
next_marginals::MarginalsCSBM{R};
83+
marginals::MarginalsCSBM{R},
8484
observations::ObservationsCSBM{R},
8585
csbm::CSBM{R},
8686
) where {R}
8787
(; d, λ, μ, N, P) = csbm
88-
(; g, B, Ξ) = observations
88+
(; g, Ξ, B) = observations
8989
(; cᵢ, cₒ) = affinities(csbm)
9090

91-
ûᵗ, v̂ᵗ, χ₊eᵗ = marginals.û, marginals.v̂, marginals.χ₊e
92-
ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_marginals.û, next_marginals.v̂, next_marginals.χ₊e
91+
ûᵗ, v̂ᵗ, χe₊ᵗ = marginals.û, marginals.v̂, marginals.χe₊
92+
ûᵗ⁺¹, v̂ᵗ⁺¹, χe₊ᵗ⁺¹ = next_marginals.û, next_marginals.v̂, next_marginals.χe₊
9393
(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = next_marginals
9494

9595
ûₜ_sum = sum(ûᵗ)
9696
ûₜ_sum2 = sum(abs2, ûᵗ)
9797

98-
# CSBMAMP estimation of v
98+
# AMP estimation of v
9999
σᵥ_no_comm =/ N) * ûₜ_sum2
100100
mul!(v̂_no_comm, B, ûᵗ)
101101
v̂_no_comm .*= sqrt/ N)
@@ -118,23 +118,23 @@ function update_amp!(
118118
for i in 1:N
119119
s_i = h̃₊[i] - h̃₋[i]
120120
for k in neighbors(g, i)
121-
common = 2λ * sqrt(d) * χ₊eᵗ[k, i]
121+
common = 2λ * sqrt(d) * χe₊ᵗ[k, i]
122122
s_i += log((cₒ + common) / (cᵢ - common))
123123
end
124124
χ₊[i] = s_i
125125
end
126126

127127
# BP update of the messages
128128
for i in 1:N, j in neighbors(g, i)
129-
common = 2λ * sqrt(d) * χ₊eᵗ[j, i]
129+
common = 2λ * sqrt(d) * χe₊ᵗ[j, i]
130130
s_ij = log((cₒ + common) / (cᵢ - common))
131-
χ₊eᵗ⁺¹[i, j] = χ₊[i] - s_ij
131+
χe₊ᵗ⁺¹[i, j] = χ₊[i] - s_ij
132132
end
133133

134134
# Sigmoidize probabilities
135135
χ₊ .= sigmoid.(χ₊)
136-
for (key, val) in pairs(χ₊eᵗ⁺¹)
137-
χ₊eᵗ⁺¹[key] = sigmoid(val)
136+
for (key, val) in pairs(χe₊ᵗ⁺¹)
137+
χe₊ᵗ⁺¹[key] = sigmoid(val)
138138
end
139139

140140
# BP estimation of u

src/glmsbm.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ end
105105

106106
Base.rand(rng::AbstractRNG, ::GaussianWeightPrior{R}) where {R} = randn(rng, R)
107107

108-
fₐ(::RademacherWeightPrior{R}, Λ, Γ) where {R} = tanh(Γ)
109-
fᵥ(::RademacherWeightPrior{R}, Λ, Γ) where {R} = inv(abs2(cosh(Γ)))
108+
fₐ(::RademacherWeightPrior, Λ, Γ) = tanh(Γ)
109+
fᵥ(::RademacherWeightPrior, Λ, Γ) = inv(abs2(cosh(Γ)))
110110

111-
fₐ(::GaussianWeightPrior{R}, Λ, Γ) where {R} = Γ /+ 1)
112-
fᵥ(::GaussianWeightPrior{R}, Λ, Γ) where {R} = 1 /+ 1)
111+
fₐ(::GaussianWeightPrior, Λ, Γ) = Γ /+ 1)
112+
fᵥ(::GaussianWeightPrior, Λ, Γ) = 1 /+ 1)

0 commit comments

Comments
 (0)