|
| 1 | +using JET |
1 | 2 | using Random: default_rng |
2 | 3 | using Statistics |
3 | 4 | using StochasticBlockModelVariants |
4 | 5 | using Test |
5 | 6 |
|
6 | 7 | rng = default_rng() |
7 | 8 |
|
8 | | -function test_recovery(; N, P, d, λ, μ, ρ, init_std=1e-3, iterations=20) |
9 | | - csbm = ContextualSBM(; N, P, d, λ, μ, ρ) |
| 9 | +function test_recovery(csbm::ContextualSBM) |
10 | 10 | @assert effective_snr(csbm) > 1 |
11 | 11 | (; latents, observations) = rand(rng, csbm) |
12 | | - storage_history = run_amp(rng; observations, csbm, init_std, iterations) |
| 12 | + storage_history = run_amp(rng; observations, csbm) |
13 | 13 | overlap_history = [evaluate_amp(; storage, latents) for storage in storage_history] |
14 | 14 | @test last(overlap_history) > 10 * first(overlap_history) |
15 | 15 | @test last(overlap_history) > 0.5 |
16 | 16 | return nothing |
17 | 17 | end |
18 | 18 |
|
19 | | -test_recovery(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0) # AMP-BP |
20 | | -test_recovery(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0) # AMP |
21 | | -test_recovery(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0) # BP |
| 19 | +function test_jet(csbm::ContextualSBM) |
| 20 | + (; observations) = rand(rng, csbm) |
| 21 | + @test_opt target_modules = (StochasticBlockModelVariants,) run_amp( |
| 22 | + rng; observations, csbm, iterations=2 |
| 23 | + ) |
| 24 | + @test_call target_modules = (StochasticBlockModelVariants,) run_amp( |
| 25 | + rng; observations, csbm, iterations=2 |
| 26 | + ) |
| 27 | + return nothing |
| 28 | +end |
| 29 | + |
| 30 | +function test_allocations(csbm::ContextualSBM) |
| 31 | + (; observations) = rand(rng, csbm) |
| 32 | + (; storage, next_storage, temp_storage) = init_amp( |
| 33 | + rng; observations, csbm, init_std=1e-3 |
| 34 | + ) |
| 35 | + alloc = @allocated update_amp!(next_storage, temp_storage; storage, observations, csbm) |
| 36 | + @test alloc == 0 |
| 37 | + return nothing |
| 38 | +end |
| 39 | + |
| 40 | +@testset "Correct code" begin |
| 41 | + test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) # AMP-BP |
| 42 | + test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=0, μ=2, ρ=0.0)) # AMP |
| 43 | + test_recovery(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=0, ρ=0.0)) # BP |
| 44 | +end |
| 45 | + |
| 46 | +@testset "Good code" begin |
| 47 | + test_jet(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) |
| 48 | + test_allocations(ContextualSBM(; N=10^3, P=10^3, d=5, λ=2, μ=2, ρ=0.0)) |
| 49 | +end |
0 commit comments