From 1742b9b8a5a77cf21b0efcb46dac4045fedcf92a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 30 Sep 2025 18:17:15 +0100 Subject: [PATCH 1/3] Remove resume_from --- HISTORY.md | 5 +++++ src/DynamicPPL.jl | 2 ++ src/sampler.jl | 27 +++++++-------------------- test/sampler.jl | 26 +++----------------------- 4 files changed, 17 insertions(+), 43 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f69c4a6fd..29bc56493 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. +### Removal of `resume_from` + +The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. +`loadstate` is exported from DynamicPPL. + **Other changes** ### `setleafcontext(model, context)` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 31adadb55..43180b091 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,6 +130,8 @@ export AbstractVarInfo, prefix, returned, to_submodel, + # Chain save/resume + loadstate, # Convenience macros @addlogprob!, value_iterator_from_chain, diff --git a/src/sampler.jl b/src/sampler.jl index c598e13f5..902e6a9c5 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,17 +58,15 @@ function AbstractMCMC.sample( model::Model, sampler::Sampler, N::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, initial_params=init_strategy(sampler), - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... + rng, model, sampler, N; initial_params, initial_state, kwargs... ) end @@ -79,10 +77,8 @@ function AbstractMCMC.sample( parallel::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, nchains::Integer; - chain_type=default_chain_type(sampler), initial_params=fill(init_strategy(sampler), nchains), - resume_from=nothing, - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @@ -95,7 +91,6 @@ function AbstractMCMC.sample( parallel, N, nchains; - chain_type, initial_params, initial_state, kwargs..., @@ -124,20 +119,12 @@ function AbstractMCMC.step( end """ - loadstate(data) + loadstate(chain::AbstractChains) -Load sampler state from `data`. - -By default, `data` is returned. -""" -loadstate(data) = data - -""" - default_chain_type(sampler) - -Default type of the chain of posterior samples from `sampler`. +Load sampler state from an `AbstractChains` object. This function should be overloaded by a +concrete Chains implementation. """ -default_chain_type(::Sampler) = Any +function loadstate end """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/test/sampler.jl b/test/sampler.jl index 5380ad17e..d19f32f94 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -12,7 +12,7 @@ @test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any end - @testset "initial_state and resume_from kwargs" begin + @testset "initial_state" begin # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our # overloaded method. @model f() = x ~ Normal() @@ -58,19 +58,10 @@ spl, N_iters; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(chn2[:x] .== initial_value) - # using `resume_from` - chn3 = sample( - model, - spl, - N_iters; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) @test all(chn3[:x] .== initial_value) end @@ -94,21 +85,10 @@ N_iters, N_chains; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) - # using `resume_from` - chn3 = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) end end From 765c7ee4c5d270b4ab84f1cb3e723b7c85614921 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 30 Sep 2025 18:26:28 +0100 Subject: [PATCH 2/3] Format --- src/sampler.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 902e6a9c5..01f056053 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -85,15 +85,7 @@ function AbstractMCMC.sample( @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - parallel, - N, - nchains; - initial_params, - initial_state, - kwargs..., + rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs... ) end From c672e3adecf6e6409decd55bd1cdad90631c65ed Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 30 Sep 2025 18:30:25 +0100 Subject: [PATCH 3/3] Fix test --- test/sampler.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/sampler.jl b/test/sampler.jl index d19f32f94..8be54901d 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -52,7 +52,6 @@ chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) initial_value = chn[:x][1] @test all(chn[:x] .== initial_value) # sanity check - # using `initial_state` chn2 = sample( model, spl, @@ -62,7 +61,6 @@ chain_type=MCMCChains.Chains, ) @test all(chn2[:x] .== initial_value) - @test all(chn3[:x] .== initial_value) end @testset "multiple-chain sampling" begin @@ -77,7 +75,6 @@ ) initial_value = chn[:x][1, :] @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check - # using `initial_state` chn2 = sample( model, spl, @@ -89,7 +86,6 @@ chain_type=MCMCChains.Chains, ) @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) - @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) end end