From 602087801aee285d5013b26e735595febbaa80fd Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 10 Oct 2025 14:04:25 +0100 Subject: [PATCH 1/6] Add support for custom metrics in HMC Extended the HMC sampler to support arbitrary metrics (mass matrices), including vectors, Diagonal, and dense matrices, for improved sampling efficiency. Updated momenta sampling and log-probability assessment to handle these metrics. Added smoke unit tests --- src/inference/hmc.jl | 73 ++++++++++++++++++++++++++++++++++++++----- test/inference/hmc.jl | 44 +++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 8 deletions(-) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 95a6fdeeb..00716f9a5 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -1,7 +1,27 @@ +# Momenta sampling with different metrics + function sample_momenta(n::Int) Float64[random(normal, 0, 1) for _=1:n] end +function sample_momenta(n::Int, metric::AbstractVector) + return sqrt.(metric) .* sample_momenta(n) +end + +function sample_momenta(n::Int, metric::LinearAlgebra.Diagonal) + sample_momenta(n::Int, LinearAlgebra.diag(metric)) +end + +function sample_momenta(n::Int, metric::AbstractMatrix) + mvnormal(zeros(n), metric) +end + +function sample_momenta(n::Int, metric::Nothing) + sample_momenta(n) +end + +# Assessing momenta log probabilities with different metrics + function assess_momenta(momenta) logprob = 0. for val in momenta @@ -10,21 +30,60 @@ function assess_momenta(momenta) logprob end +function assess_momenta(momenta, metric::AbstractVector) + logprob = 0. + for (val, m) in zip(momenta, metric) + logprob += logpdf(normal, val, 0, sqrt(m)) + end + logprob +end + +function assess_momenta(momenta, metric::LinearAlgebra.Diagonal) + assess_momenta(momenta, LinearAlgebra.diag(metric)) +end + +function assess_momenta(momenta, metric::AbstractMatrix) + logpdf(mvnormal, momenta, zeros(length(momenta)), metric) +end + +function assess_momenta(momenta, metric::Nothing) + assess_momenta(momenta) +end + """ (new_trace, accepted) = hmc( trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), metric = nothing) + +Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the +selected addresses, returning the new trace (which is equal to the previous trace +if the move was not accepted) and a `Bool` indicating whether the move was accepted or not. + +Hamilton's equations are numerically integrated using leapfrog integration with +step size `eps` for `L` steps and initial momenta sampled from a Gaussian distribution with +covariance given by `metric` (mass matrix). + +## `metric` options + +Sampling with HMC is improved by using a metric/mass matrix that approximates the +**inverse** covariance of the target distribution, and is equivalent to a linear transformation +of the parameter space (see Neal, 2011). + +The following options are supported for `metric`: -Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not. 
+- `nothing` (default): identity matrix +- `Vector`: diagonal matrix with the given vector as the diagonal +- `Diagonal`: diagonal matrix lowers to the vector of the diagonal entries +- `Matrix`: dense matrix -Hamilton's equations are numerically integrated using leapfrog integration with step size `eps` for `L` steps. See equations (5.18)-(5.20) of Neal (2011). +See equations (5.18)-(5.20) of Neal (2011). # References Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf """ function hmc( trace::Trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), metric = nothing) prev_model_score = get_score(trace) args = get_args(trace) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing @@ -35,8 +94,8 @@ function hmc( (_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) values = to_array(values_trie, Float64) gradient = to_array(gradient_trie, Float64) - momenta = sample_momenta(length(values)) - prev_momenta_score = assess_momenta(momenta) + momenta = sample_momenta(length(values), metric) + prev_momenta_score = assess_momenta(momenta, metric) for step=1:L # half step on momenta @@ -60,7 +119,7 @@ function hmc( new_model_score = get_score(new_trace) # assess new momenta score (negative kinetic energy) - new_momenta_score = assess_momenta(-momenta) + new_momenta_score = assess_momenta(-momenta, metric) # accept or reject alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index c5887751c..25575154a 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -1,5 +1,5 @@ @testset "hmc" begin - + import LinearAlgebra # smoke test a function without retval gradient @gen function foo() x = @trace(normal(0, 1), :x) @@ -17,4 +17,46 @@ (trace, _) = generate(foo, ()) (new_trace, accepted) = hmc(trace, select(:x)) + + # smoke test with vector metric + @gen function bar() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace, _) = generate(bar, ()) + metric_vec = [1.0, 2.0] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + + # smoke test with Diagonal metric + (trace, _) = generate(bar, ()) + metric_diag = LinearAlgebra.Diagonal([1.0, 2.0]) + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag) + + # smoke test with Dense matrix metric + (trace, _) = generate(bar, ()) + metric_dense = [1.0 0.1; 0.1 2.0] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) + + # smoke test with vector metric and retval gradient + @gen (grad) function bar_grad() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace, _) = generate(bar_grad, ()) + metric_vec = [0.5, 1.5] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + + # smoke test with Diagonal metric and retval gradient + (trace, _) = generate(bar_grad, ()) + metric_diag = LinearAlgebra.Diagonal([0.5, 1.5]) + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag) + + # smoke test with Dense matrix metric and retval gradient + (trace, _) = generate(bar_grad, ()) + metric_dense = [0.5 0.2; 0.2 1.5] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) end From 7ee3a1b2f6390d1b4d4512eef36a651e90be513a Mon Sep 17 00:00:00 2001 
From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 10 Oct 2025 14:30:01 +0100 Subject: [PATCH 2/6] add check for diag case --- src/inference/hmc.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 00716f9a5..db32a1467 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -5,6 +5,7 @@ function sample_momenta(n::Int) end function sample_momenta(n::Int, metric::AbstractVector) + @assert all(>(0), metric) "All diagonal metric values must be positive" return sqrt.(metric) .* sample_momenta(n) end From c865885b8bcf195cd50951d4e2686c26ab2b64bb Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:04:46 +0100 Subject: [PATCH 3/6] Add a grad test --- test/inference/hmc.jl | 115 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index 25575154a..4ca9d9c34 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -1,5 +1,6 @@ -@testset "hmc" begin - import LinearAlgebra +@testset "hmc tests" begin + import Distributions, LinearAlgebra, Random + # smoke test a function without retval gradient @gen function foo() x = @trace(normal(0, 1), :x) @@ -18,6 +19,13 @@ (trace, _) = generate(foo, ()) (new_trace, accepted) = hmc(trace, select(:x)) + (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0) + values = to_array(values_trie, Float64) + grad = to_array(gradient_trie, Float64) + + # For Normal(0,1), grad should be -x + @test values ≈ -grad + # smoke test with vector metric @gen function bar() x = @trace(normal(0, 1), :x) @@ -29,6 +37,10 @@ metric_vec = [1.0, 2.0] (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + + # For Normal(0,1), grad should be -x + @test values ≈ -grad + # smoke test with Diagonal metric (trace, _) = generate(bar, ()) metric_diag = LinearAlgebra.Diagonal([1.0, 2.0]) @@ -50,6 +62,13 @@ metric_vec = [0.5, 1.5] (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0) + values = to_array(values_trie, Float64) + grad = to_array(gradient_trie, Float64) + + # For each Normal(0,1), grad should be -x + @test values ≈ -grad + # smoke test with Diagonal metric and retval gradient (trace, _) = generate(bar_grad, ()) metric_diag = LinearAlgebra.Diagonal([0.5, 1.5]) @@ -60,3 +79,95 @@ metric_dense = [0.5 0.2; 0.2 1.5] (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) end + +## + +@testset "hmc metric behavior" begin + import Distributions, LinearAlgebra, Random + + # RNG state for reproducibility + # As per 1.12 this can be passed to the testset + rng=Random.Xoshiro(0x2e026445595ed28e) + + # test that different metrics produce different behavior + @gen function test_metric_effect() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace1, _) = generate(test_metric_effect, ()) + + + # Set RNG to a known state for comparison + Random.seed!(rng, 1) + + # Run HMC with identity metric (default) + (trace_identity, _) = hmc(trace1, select(:x, :y); L=5) + + # Reset RNG to same state for comparison + Random.seed!(rng, 1) + + # Run HMC with scaled metric (should behave differently) + metric_scaled = [10.0, 0.1] # Very different scales + (trace_scaled, _) = hmc(trace1, select(:x, :y); L=5, metric=metric_scaled) + + # With same RNG sequence but different metrics, 
should get different results + @test get_choices(trace_identity) != get_choices(trace_scaled) + + # test that diagonal and dense metrics with same diagonal values are similar + # @gen function test_diagonal_equivalence() + # x = @trace(normal(0, 1), :x) + # y = @trace(normal(0, 1), :y) + # return x + y + # end + + # (trace2, _) = generate(test_diagonal_equivalence, ()) + + # # Test many times to check statistical similarity + # acceptances_diag = Float64[] + # acceptances_dense = Float64[] + + # for i in 1:50 + # # Reset to predictable state for each iteration + # Random.seed!(Random.default_rng(), (initial_state[1] + i, initial_state[2], initial_state[3], initial_state[4], initial_state[5])) + # (_, accepted_diag) = hmc(trace2, select(:x, :y); + # metric=LinearAlgebra.Diagonal([2.0, 3.0])) + + # # Reset to same state for comparison + # Random.seed!(Random.default_rng(), (initial_state[1] + i, initial_state[2], initial_state[3], initial_state[4], initial_state[5])) + # (_, accepted_dense) = hmc(trace2, select(:x, :y); + # metric=[2.0 0.0; 0.0 3.0]) + + # push!(acceptances_diag, accepted_diag ? 1.0 : 0.0) + # push!(acceptances_dense, accepted_dense ? 1.0 : 0.0) + # end + + # # Should have similar acceptance rates (within 20%) + # rate_diag = Distributions.mean(acceptances_diag) + # rate_dense = Distributions.mean(acceptances_dense) + # @test abs(rate_diag - rate_dense) < 0.2 + + +end + +## + +@testset "Bad metric catches" begin + @gen function bar() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + bad_metrics =([-1.0 -20.0; 0.0 1.0], # Bad dense, + LinearAlgebra.Diagonal([-1.0, -20.0]), # Bad diag + [-5.0, 20.0], # Bad vector diag + ) + + for bad_metric in bad_metrics + (trace3, _) = generate(bar, ()) + @test_throws Exception hmc(trace3, select(:x, :y); metric=bad_metric) + end + +end \ No newline at end of file From 265ee8e40e6af44ed5f657daed7bc0419caa767e Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:22:07 +0100 Subject: [PATCH 4/6] Refactor and expand HMC tests for metrics and gradients - Classic grad = -x for logpdf of x ~ N(0,1) check - Check that identical metrics in different forms give similar sampling - Check that different metrics give different sampling - Check the bad metric catches --- test/inference/hmc.jl | 71 ++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 44 deletions(-) diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index 4ca9d9c34..c95e54e43 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -1,5 +1,5 @@ @testset "hmc tests" begin - import Distributions, LinearAlgebra, Random + import LinearAlgebra, Random # smoke test a function without retval gradient @gen function foo() @@ -19,11 +19,10 @@ (trace, _) = generate(foo, ()) (new_trace, accepted) = hmc(trace, select(:x)) + # For Normal(0,1), grad should be -x (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0) values = to_array(values_trie, Float64) grad = to_array(gradient_trie, Float64) - - # For Normal(0,1), grad should be -x @test values ≈ -grad # smoke test with vector metric @@ -37,10 +36,6 @@ metric_vec = [1.0, 2.0] (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) - - # For Normal(0,1), grad should be -x - @test values ≈ -grad - # smoke test with Diagonal metric (trace, _) = generate(bar, ()) metric_diag = LinearAlgebra.Diagonal([1.0, 2.0]) @@ -62,11 +57,10 @@ metric_vec = [0.5, 1.5] (new_trace, accepted) = 
hmc(trace, select(:x, :y); metric=metric_vec) + # For each Normal(0,1), grad should be -x (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0) values = to_array(values_trie, Float64) grad = to_array(gradient_trie, Float64) - - # For each Normal(0,1), grad should be -x @test values ≈ -grad # smoke test with Diagonal metric and retval gradient @@ -80,13 +74,11 @@ (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) end -## - @testset "hmc metric behavior" begin - import Distributions, LinearAlgebra, Random + import LinearAlgebra, Random # RNG state for reproducibility - # As per 1.12 this can be passed to the testset + # As per Julia 1.12 this can be passed to the testset but would fail ci rng=Random.Xoshiro(0x2e026445595ed28e) # test that different metrics produce different behavior @@ -115,44 +107,35 @@ end # With same RNG sequence but different metrics, should get different results @test get_choices(trace_identity) != get_choices(trace_scaled) - # test that diagonal and dense metrics with same diagonal values are similar - # @gen function test_diagonal_equivalence() - # x = @trace(normal(0, 1), :x) - # y = @trace(normal(0, 1), :y) - # return x + y - # end - - # (trace2, _) = generate(test_diagonal_equivalence, ()) + # With same metric but different representations, should get similar results + # Test many times to check statistical similarity + acceptances_diag = Float64[] + acceptances_dense = Float64[] - # # Test many times to check statistical similarity - # acceptances_diag = Float64[] - # acceptances_dense = Float64[] - - # for i in 1:50 - # # Reset to predictable state for each iteration - # Random.seed!(Random.default_rng(), (initial_state[1] + i, initial_state[2], initial_state[3], initial_state[4], initial_state[5])) - # (_, accepted_diag) = hmc(trace2, select(:x, :y); - # metric=LinearAlgebra.Diagonal([2.0, 3.0])) + for i in 1:50 + # Reset to predictable state for each iteration + Random.seed!(rng, i) + (_, accepted_diag) = hmc(trace1, select(:x, :y); + metric=LinearAlgebra.Diagonal([2.0, 3.0])) - # # Reset to same state for comparison - # Random.seed!(Random.default_rng(), (initial_state[1] + i, initial_state[2], initial_state[3], initial_state[4], initial_state[5])) - # (_, accepted_dense) = hmc(trace2, select(:x, :y); - # metric=[2.0 0.0; 0.0 3.0]) + # Reset to same state for comparison + Random.seed!(rng, i) + (_, accepted_dense) = hmc(trace1, select(:x, :y); + metric=[2.0 0.0; 0.0 3.0]) - # push!(acceptances_diag, accepted_diag ? 1.0 : 0.0) - # push!(acceptances_dense, accepted_dense ? 
1.0 : 0.0) - # end + # Collect acceptance results + push!(acceptances_diag, float(accepted_diag)) + push!(acceptances_dense, float(accepted_dense)) + end # # Should have similar acceptance rates (within 20%) - # rate_diag = Distributions.mean(acceptances_diag) - # rate_dense = Distributions.mean(acceptances_dense) - # @test abs(rate_diag - rate_dense) < 0.2 + rate_diag = Distributions.mean(acceptances_diag) + rate_dense = Distributions.mean(acceptances_dense) + @test abs(rate_diag - rate_dense) < 0.2 end -## - @testset "Bad metric catches" begin @gen function bar() x = @trace(normal(0, 1), :x) @@ -166,8 +149,8 @@ end ) for bad_metric in bad_metrics - (trace3, _) = generate(bar, ()) - @test_throws Exception hmc(trace3, select(:x, :y); metric=bad_metric) + (trace, _) = generate(bar, ()) + @test_throws Exception hmc(trace, select(:x, :y); metric=bad_metric) end end \ No newline at end of file From 7d0f535e23097b4e7da59d5730006f31f7514525 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:43:17 +0100 Subject: [PATCH 5/6] Change to global rng This maintains backwards compat with Julia 1.6 --- test/inference/hmc.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index c95e54e43..e465351b6 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -76,10 +76,6 @@ end @testset "hmc metric behavior" begin import LinearAlgebra, Random - - # RNG state for reproducibility - # As per Julia 1.12 this can be passed to the testset but would fail ci - rng=Random.Xoshiro(0x2e026445595ed28e) # test that different metrics produce different behavior @gen function test_metric_effect() @@ -92,13 +88,13 @@ end # Set RNG to a known state for comparison - Random.seed!(rng, 1) + Random.seed!(1) # Run HMC with identity metric (default) (trace_identity, _) = hmc(trace1, select(:x, :y); L=5) # Reset RNG to same state for comparison - Random.seed!(rng, 1) + Random.seed!(1) # Run HMC with scaled metric (should behave differently) metric_scaled = [10.0, 0.1] # Very different scales @@ -114,12 +110,12 @@ end for i in 1:50 # Reset to predictable state for each iteration - Random.seed!(rng, i) + Random.seed!(i) (_, accepted_diag) = hmc(trace1, select(:x, :y); metric=LinearAlgebra.Diagonal([2.0, 3.0])) # Reset to same state for comparison - Random.seed!(rng, i) + Random.seed!(i) (_, accepted_dense) = hmc(trace1, select(:x, :y); metric=[2.0 0.0; 0.0 3.0]) From 3609d3e92406e4bbf3a5311670fe7be8d9377cb8 Mon Sep 17 00:00:00 2001 From: ztangent Date: Wed, 22 Oct 2025 16:37:48 +0800 Subject: [PATCH 6/6] Minor docstring changes. --- src/inference/hmc.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index db32a1467..5951aeea1 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -64,13 +64,9 @@ Hamilton's equations are numerically integrated using leapfrog integration with step size `eps` for `L` steps and initial momenta sampled from a Gaussian distribution with covariance given by `metric` (mass matrix). -## `metric` options - Sampling with HMC is improved by using a metric/mass matrix that approximates the **inverse** covariance of the target distribution, and is equivalent to a linear transformation -of the parameter space (see Neal, 2011). - -The following options are supported for `metric`: +of the parameter space (see Neal, 2011). 
The following options are supported for `metric`: - `nothing` (default): identity matrix - `Vector`: diagonal matrix with the given vector as the diagonal
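---

Not part of the patch series: a minimal usage sketch of the `metric` keyword introduced in patch 1/6, assuming the `hmc` signature and metric dispatch defined there. The model, step-size settings, and metric values below are illustrative only; per the docstring, `metric` should approximate the **inverse** covariance of the target, so a coordinate with standard deviation 3 gets an entry of roughly 1/9.

```julia
using Gen
import LinearAlgebra

# Toy model with two continuous addresses on different scales (illustrative).
@gen function model()
    x = @trace(normal(0, 1), :x)
    y = @trace(normal(0, 3), :y)
    return x + y
end

(trace, _) = generate(model, ())

# Default behaviour (identity mass matrix), unchanged from before the patch.
(trace, accepted) = hmc(trace, select(:x, :y); L=10, eps=0.1)

# Diagonal metric given as a vector of the diagonal entries.
(trace, accepted) = hmc(trace, select(:x, :y); metric=[1.0, 1/9])

# Same metric in Diagonal form; dispatches to the vector method internally.
(trace, accepted) = hmc(trace, select(:x, :y); metric=LinearAlgebra.Diagonal([1.0, 1/9]))

# Dense (positive-definite) mass matrix; momenta are drawn from mvnormal.
(trace, accepted) = hmc(trace, select(:x, :y); metric=[1.0 0.1; 0.1 1/9])
```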