Skip to content

Commit c065ebc

Browse files
committed
More unified interface
1 parent 08ad19a commit c065ebc

File tree

9 files changed

+232
-179
lines changed

9 files changed

+232
-179
lines changed

docs/Manifest.toml

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.10.0-beta2"
3+
julia_version = "1.9.3"
44
manifest_format = "2.0"
55
project_hash = "8f377c7585a608434bf87f0689076fc3a3224468"
66

@@ -212,7 +212,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
212212
[[deps.CompilerSupportLibraries_jll]]
213213
deps = ["Artifacts", "Libdl"]
214214
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
215-
version = "1.0.5+1"
215+
version = "1.0.5+0"
216216

217217
[[deps.ConstructionBase]]
218218
deps = ["LinearAlgebra"]
@@ -230,6 +230,11 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781"
230230
uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
231231
version = "0.6.2"
232232

233+
[[deps.Crayons]]
234+
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
235+
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
236+
version = "4.1.1"
237+
233238
[[deps.DataAPI]]
234239
git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
235240
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -713,12 +718,12 @@ version = "0.3.1"
713718
[[deps.LibCURL]]
714719
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
715720
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
716-
version = "0.6.4"
721+
version = "0.6.3"
717722

718723
[[deps.LibCURL_jll]]
719724
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
720725
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
721-
version = "8.0.1+1"
726+
version = "7.84.0+0"
722727

723728
[[deps.LibGit2]]
724729
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
@@ -727,7 +732,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
727732
[[deps.LibSSH2_jll]]
728733
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
729734
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
730-
version = "1.11.0+1"
735+
version = "1.10.2+0"
731736

732737
[[deps.Libdl]]
733738
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -838,6 +843,12 @@ git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e"
838843
uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
839844
version = "0.4.2"
840845

846+
[[deps.MarchingCubes]]
847+
deps = ["PrecompileTools", "StaticArrays"]
848+
git-tree-sha1 = "c8e29e2bacb98c9b6f10445227a8b0402f2f173a"
849+
uuid = "299715c1-40a9-479a-aaf9-4a633d36f717"
850+
version = "0.1.8"
851+
841852
[[deps.Markdown]]
842853
deps = ["Base64"]
843854
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -856,7 +867,7 @@ version = "0.5.6"
856867
[[deps.MbedTLS_jll]]
857868
deps = ["Artifacts", "Libdl"]
858869
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
859-
version = "2.28.2+1"
870+
version = "2.28.2+0"
860871

861872
[[deps.Missings]]
862873
deps = ["DataAPI"]
@@ -880,7 +891,7 @@ version = "0.3.4"
880891

881892
[[deps.MozillaCACerts_jll]]
882893
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
883-
version = "2023.1.10"
894+
version = "2022.10.11"
884895

885896
[[deps.Multisets]]
886897
git-tree-sha1 = "8d852646862c96e226367ad10c8af56099b4047e"
@@ -929,7 +940,7 @@ version = "1.3.5+1"
929940
[[deps.OpenBLAS_jll]]
930941
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
931942
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
932-
version = "0.3.23+2"
943+
version = "0.3.21+4"
933944

934945
[[deps.OpenEXR]]
935946
deps = ["Colors", "FileIO", "OpenEXR_jll"]
@@ -946,7 +957,7 @@ version = "3.1.4+0"
946957
[[deps.OpenLibm_jll]]
947958
deps = ["Artifacts", "Libdl"]
948959
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
949-
version = "0.8.1+2"
960+
version = "0.8.1+0"
950961

951962
[[deps.OpenSSL_jll]]
952963
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -980,7 +991,7 @@ version = "1.6.2"
980991
[[deps.PCRE2_jll]]
981992
deps = ["Artifacts", "Libdl"]
982993
uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15"
983-
version = "10.42.0+1"
994+
version = "10.42.0+0"
984995

985996
[[deps.PDMats]]
986997
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1039,7 +1050,7 @@ version = "0.42.2+0"
10391050
[[deps.Pkg]]
10401051
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
10411052
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1042-
version = "1.10.0"
1053+
version = "1.9.2"
10431054

10441055
[[deps.PkgVersion]]
10451056
deps = ["Pkg"]
@@ -1127,7 +1138,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
11271138
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
11281139

11291140
[[deps.Random]]
1130-
deps = ["SHA"]
1141+
deps = ["SHA", "Serialization"]
11311142
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11321143

11331144
[[deps.RangeArrays]]
@@ -1291,7 +1302,6 @@ version = "1.1.1"
12911302
[[deps.SparseArrays]]
12921303
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
12931304
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1294-
version = "1.10.0"
12951305

12961306
[[deps.SpecialFunctions]]
12971307
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
@@ -1359,7 +1369,7 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"]
13591369
StatsFunsInverseFunctionsExt = "InverseFunctions"
13601370

13611371
[[deps.StochasticBlockModelVariants]]
1362-
deps = ["DensityInterface", "DocStringExtensions", "Graphs", "LinearAlgebra", "PrecompileTools", "ProgressMeter", "Random", "SimpleWeightedGraphs", "SparseArrays", "SpecialFunctions", "Statistics"]
1372+
deps = ["DensityInterface", "DocStringExtensions", "ForwardDiff", "Graphs", "LinearAlgebra", "Optim", "PrecompileTools", "ProgressMeter", "Random", "SimpleWeightedGraphs", "SparseArrays", "SpecialFunctions", "Statistics", "UnicodePlots"]
13631373
path = ".."
13641374
uuid = "37adabcb-1964-4c56-850d-39a4302c0c39"
13651375
version = "0.1.0"
@@ -1377,7 +1387,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
13771387
[[deps.SuiteSparse_jll]]
13781388
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
13791389
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
1380-
version = "7.2.0+1"
1390+
version = "5.10.1+6"
13811391

13821392
[[deps.TOML]]
13831393
deps = ["Dates"]
@@ -1451,6 +1461,27 @@ git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf"
14511461
uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1"
14521462
version = "0.4.1"
14531463

1464+
[[deps.UnicodePlots]]
1465+
deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"]
1466+
git-tree-sha1 = "b96de03092fe4b18ac7e4786bee55578d4b75ae8"
1467+
uuid = "b8865327-cd53-5732-bb35-84acbb429228"
1468+
version = "3.6.0"
1469+
1470+
[deps.UnicodePlots.extensions]
1471+
FreeTypeExt = ["FileIO", "FreeType"]
1472+
ImageInTerminalExt = "ImageInTerminal"
1473+
IntervalSetsExt = "IntervalSets"
1474+
TermExt = "Term"
1475+
UnitfulExt = "Unitful"
1476+
1477+
[deps.UnicodePlots.weakdeps]
1478+
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
1479+
FreeType = "b38be410-82b0-50bf-ab77-7b57e271db43"
1480+
ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
1481+
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
1482+
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
1483+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1484+
14541485
[[deps.WoodburyMatrices]]
14551486
deps = ["LinearAlgebra", "SparseArrays"]
14561487
git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3"
@@ -1520,7 +1551,7 @@ version = "1.5.0+0"
15201551
[[deps.Zlib_jll]]
15211552
deps = ["Libdl"]
15221553
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
1523-
version = "1.2.13+1"
1554+
version = "1.2.13+0"
15241555

15251556
[[deps.isoband_jll]]
15261557
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1543,7 +1574,7 @@ version = "0.15.1+0"
15431574
[[deps.libblastrampoline_jll]]
15441575
deps = ["Artifacts", "Libdl"]
15451576
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
1546-
version = "5.8.0+1"
1577+
version = "5.8.0+0"
15471578

15481579
[[deps.libfdk_aac_jll]]
15491580
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1572,12 +1603,12 @@ version = "1.3.7+1"
15721603
[[deps.nghttp2_jll]]
15731604
deps = ["Artifacts", "Libdl"]
15741605
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
1575-
version = "1.52.0+1"
1606+
version = "1.48.0+0"
15761607

15771608
[[deps.p7zip_jll]]
15781609
deps = ["Artifacts", "Libdl"]
15791610
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
1580-
version = "17.4.0+2"
1611+
version = "17.4.0+0"
15811612

15821613
[[deps.x264_jll]]
15831614
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]

docs/plots_csbm.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Base.Threads
22
using CairoMakie
33
using LinearAlgebra
4-
using Random: default_rng
4+
using Random: default_rng, rand!
55
using Statistics
66
using StochasticBlockModelVariants
77
using ProgressMeter
@@ -17,11 +17,13 @@ function compute_fig1_csbm(; α, d, μ, ρ, N_values, λ_values, trials)
1717
prog = Progress(I * J * K; desc="CSBM - Fig 1")
1818
for i in 1:I
1919
for j in 1:J
20-
@threads for k in 1:K
21-
N, λ = N_values[i], λ_values[j]
22-
P = ceil(Int, N / α)
23-
csbm = CSBM(; N, P, d, μ, λ, ρ)
24-
(qu, qv, converged) = evaluate_amp(rng, csbm)
20+
N, λ = N_values[i], λ_values[j]
21+
P = ceil(Int, N / α)
22+
csbm = CSBM(; N, P, d, μ, λ, ρ)
23+
B = rand(rng, csbm).observations.B
24+
for k in 1:K # don't parallelize
25+
(; latents, observations) = rand!(rng, B, csbm)
26+
(qu, qv, converged) = evaluate_amp(rng, csbm, latents, observations)
2527
qu_values[i, j, k] = qu
2628
qv_values[i, j, k] = qv
2729
converged_values[i, j, k] = converged

docs/plots_glmsbm.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Base.Threads
22
using CairoMakie
33
using LinearAlgebra
4-
using Random: default_rng
4+
using Random: default_rng, rand!
55
using Statistics
66
using StochasticBlockModelVariants
77
using ProgressMeter
@@ -17,11 +17,13 @@ function compute_fig1_glmsbm(; N, c, ρ, Pʷ, α_values, λ_values, trials)
1717
prog = Progress(I * J * K; desc="CSBM - Fig 1")
1818
for i in 1:I
1919
for j in 1:J
20-
@threads for k in 1:K
21-
α, λ = α_values[i], λ_values[j]
22-
M = ceil(Int, N / α)
23-
glmsbm = GLMSBM(; N, M, c, λ, ρ, Pʷ)
24-
(qs, qw, converged) = evaluate_amp(rng, glmsbm)
20+
α, λ = α_values[i], λ_values[j]
21+
M = ceil(Int, N / α)
22+
glmsbm = GLMSBM(; N, M, c, λ, ρ, Pʷ)
23+
F = rand(rng, glmsbm).observations.F
24+
for k in 1:K # don't parallelize
25+
(; latents, observations) = rand!(rng, F, glmsbm)
26+
(qs, qw, converged) = evaluate_amp(rng, glmsbm, latents, observations)
2527
qs_values[i, j, k] = qs
2628
qw_values[i, j, k] = qw
2729
converged_values[i, j, k] = converged
@@ -43,8 +45,8 @@ function plot_fig1_glmsbm(res; α_values, λ_values)
4345
lines!(ax1, λ_values, qs_means; label="α=")
4446
errorbars!(ax1, λ_values, qs_means, qs_stds; label=nothing)
4547

46-
qw_means = dropdims(mean(qs_values[i, :, :]; dims=2); dims=2)
47-
qw_stds = dropdims(std(qs_values[i, :, :]; dims=2); dims=2)
48+
qw_means = dropdims(mean(qw_values[i, :, :]; dims=2); dims=2)
49+
qw_stds = dropdims(std(qw_values[i, :, :]; dims=2); dims=2)
4850
lines!(ax2, λ_values, qw_means; label="α=")
4951
errorbars!(ax2, λ_values, qw_means, qw_stds; label=nothing)
5052
end
@@ -58,6 +60,6 @@ Pʷ = GaussianWeightPrior()
5860
α_values = reverse([0.3, 1, 3, 10, 30])
5961
α_values = reverse([0.3, 1])
6062
λ_values = 0:0.1:2
61-
trials = 10
63+
trials = 1
6264
res1 = compute_fig1_glmsbm(; N, c, ρ, Pʷ, α_values, λ_values, trials) # kills Julia
6365
plot_fig1_glmsbm(res1; α_values, λ_values)

src/StochasticBlockModelVariants.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Graphs: AbstractGraph, neighbors
1515
using LinearAlgebra: LinearAlgebra, dot, mul!, norm
1616
using PrecompileTools: @compile_workload
1717
using ProgressMeter: Progress, next!
18-
using Random: AbstractRNG, default_rng
18+
using Random: Random, AbstractRNG, default_rng, rand!, randn!
1919
using SimpleWeightedGraphs: SimpleWeightedGraph
2020
using SpecialFunctions: erf
2121
using Statistics: mean, std
@@ -26,7 +26,7 @@ export nb_features, average_degree, communities_snr, affinities, features_snr, e
2626
export CSBM, LatentsCSBM, ObservationsCSBM
2727
export GLMSBM, LatentsGLMSBM, ObservationsGLMSBM
2828
export GaussianWeightPrior, RademacherWeightPrior
29-
export init_amp, update_amp!, run_amp, evaluate_amp
29+
export init_amp, update_amp!, run_amp, evaluate_amp, semisupervised_loss_amp
3030
export discrete_overlap, continuous_overlap
3131

3232
include("utils.jl")

0 commit comments

Comments
 (0)