Skip to content

Commit 9e739d8

Browse files
committed
More testing
1 parent dee15a7 commit 9e739d8

File tree

5 files changed

+75
-34
lines changed

5 files changed

+75
-34
lines changed

Manifest.toml

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.9.2"
44
manifest_format = "2.0"
5-
project_hash = "d06132272ceec5b60df1cd95ecf68b1ee04ca385"
5+
project_hash = "76c4ad29b7fdf61caa5b4fe006d178c04b45b443"
66

77
[[deps.ArgTools]]
88
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -142,16 +142,22 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "
142142
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
143143
version = "1.9.2"
144144

145+
[[deps.PrecompileTools]]
146+
deps = ["Preferences"]
147+
git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81"
148+
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
149+
version = "1.1.2"
150+
151+
[[deps.Preferences]]
152+
deps = ["TOML"]
153+
git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1"
154+
uuid = "21216c6a-2e73-6563-6e65-726566657250"
155+
version = "1.4.0"
156+
145157
[[deps.Printf]]
146158
deps = ["Unicode"]
147159
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
148160

149-
[[deps.ProgressMeter]]
150-
deps = ["Distributed", "Printf"]
151-
git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9"
152-
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
153-
version = "1.7.2"
154-
155161
[[deps.REPL]]
156162
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
157163
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ version = "0.1.0"
66
[deps]
77
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
9+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
1616
Graphs = "1.8"
17-
ProgressMeter = "1.7"
17+
PrecompileTools = "1.1"
18+
ProgressLogging = "0.1"
1819
SimpleWeightedGraphs = "1.4"
1920
julia = "1.9"
2021

src/StochasticBlockModelVariants.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module StochasticBlockModelVariants
22

33
using Graphs: AbstractGraph, neighbors
44
using LinearAlgebra: dot, mul!
5-
using ProgressMeter: @showprogress
5+
using PrecompileTools: @compile_workload
66
using Random: AbstractRNG, default_rng
77
using SimpleWeightedGraphs: SimpleWeightedGraph
88
using Statistics: mean
@@ -16,4 +16,11 @@ include("utils.jl")
1616
include("csbm.jl")
1717
include("csbm_inference.jl")
1818

19+
@compile_workload begin
20+
rng = default_rng()
21+
csbm = ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)
22+
(; observations) = rand(rng, csbm)
23+
run_amp(rng; observations, csbm, iterations=2)
24+
end
25+
1926
end

src/csbm_inference.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -123,26 +123,27 @@ function update_amp!(
123123
h̃₊ .= -h₊ .+ log(one(R) / 2) .+ û_no_feat
124124
h̃₋ .= -h₋ .+ log(one(R) / 2) .- û_no_feat
125125

126-
# BP update of the messages
127-
for i in 1:N, j in neighbors(g, i)
128-
s_ij = h̃₊[i] - h̃₋[i]
129-
for k in neighbors(g, i)
130-
if k != j
131-
common = 2λ * sqrt(d) * χ₊eᵗ[k, i]
132-
s_ij += log((cₒ + common) / (cᵢ - common))
133-
end
134-
end
135-
χ₊eᵗ⁺¹[i, j] = sigmoid(s_ij)
136-
end
137-
138126
# BP update of the marginals
139127
for i in 1:N
140128
s_i = h̃₊[i] - h̃₋[i]
141129
for k in neighbors(g, i)
142130
common = 2λ * sqrt(d) * χ₊eᵗ[k, i]
143131
s_i += log((cₒ + common) / (cᵢ - common))
144132
end
145-
χ₊[i] = sigmoid(s_i)
133+
χ₊[i] = s_i
134+
end
135+
136+
# BP update of the messages
137+
for i in 1:N, j in neighbors(g, i)
138+
common = 2λ * sqrt(d) * χ₊eᵗ[j, i]
139+
s_ij = log((cₒ + common) / (cᵢ - common))
140+
χ₊eᵗ⁺¹[i, j] = χ₊[i] - s_ij
141+
end
142+
143+
# Sigmoidize probabilities
144+
χ₊ .= sigmoid.(χ₊)
145+
for (key, val) in pairs(χ₊eᵗ⁺¹)
146+
χ₊eᵗ⁺¹[key] = sigmoid(val)
146147
end
147148

148149
# BP estimation of u
@@ -155,17 +156,15 @@ function run_amp(
155156
rng::AbstractRNG;
156157
observations::ContextualSBMObservations,
157158
csbm::ContextualSBM,
158-
init_std::Real,
159-
iterations::Integer,
159+
init_std::Real=1e-3,
160+
iterations::Integer=10,
160161
)
161162
(; storage, next_storage, temp_storage) = init_amp(rng; observations, csbm, init_std)
162163
storage_history = [copy(storage)]
163-
temp_storage_history = [copy(temp_storage)]
164-
@showprogress "AMP-BP" for iter in 1:iterations
164+
for iter in 1:iterations
165165
update_amp!(next_storage, temp_storage; storage, observations, csbm)
166166
copy!(storage, next_storage)
167167
push!(storage_history, copy(storage))
168-
push!(temp_storage_history, copy(temp_storage))
169168
end
170169
return storage_history
171170
end

test/csbm.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,49 @@
1+
using JET
12
using Random: default_rng
23
using Statistics
34
using StochasticBlockModelVariants
45
using Test
56

67
rng = default_rng()
78

8-
function test_recovery(; N, P, d, λ, μ, ρ, init_std=1e-3, iterations=20)
9-
csbm = ContextualSBM(; N, P, d, λ, μ, ρ)
9+
function test_recovery(csbm::ContextualSBM)
1010
@assert effective_snr(csbm) > 1
1111
(; latents, observations) = rand(rng, csbm)
12-
storage_history = run_amp(rng; observations, csbm, init_std, iterations)
12+
storage_history = run_amp(rng; observations, csbm)
1313
overlap_history = [evaluate_amp(; storage, latents) for storage in storage_history]
1414
@test last(overlap_history) > 10 * first(overlap_history)
1515
@test last(overlap_history) > 0.5
1616
return nothing
1717
end
1818

19-
test_recovery(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0) # AMP-BP
20-
test_recovery(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0) # AMP
21-
test_recovery(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0) # BP
19+
function test_jet(csbm::ContextualSBM)
20+
(; observations) = rand(rng, csbm)
21+
@test_opt target_modules = (StochasticBlockModelVariants,) run_amp(
22+
rng; observations, csbm, iterations=2
23+
)
24+
@test_call target_modules = (StochasticBlockModelVariants,) run_amp(
25+
rng; observations, csbm, iterations=2
26+
)
27+
return nothing
28+
end
29+
30+
function test_allocations(csbm::ContextualSBM)
31+
(; observations) = rand(rng, csbm)
32+
(; storage, next_storage, temp_storage) = init_amp(
33+
rng; observations, csbm, init_std=1e-3
34+
)
35+
alloc = @allocated update_amp!(next_storage, temp_storage; storage, observations, csbm)
36+
@test alloc == 0
37+
return nothing
38+
end
39+
40+
@testset "Correct code" begin
41+
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) # AMP-BP
42+
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0)) # AMP
43+
test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0)) # BP
44+
end
45+
46+
@testset "Good code" begin
47+
test_jet(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0))
48+
test_allocations(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0))
49+
end

0 commit comments

Comments
 (0)