Move predict from Turing #716
Changes from 15 commits
@@ -42,6 +42,156 @@
    return keys(c.info.varname_to_symbol)
end

"""
    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each
sample in `chain`, and return the resulting `Chains`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed
by the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;

julia> @model function linear_reg(x, y, σ = 0.1)
           β ~ Normal(0, 1)
           for i ∈ eachindex(y)
               y[i] ~ Normal(β * x[i], σ)
           end
       end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
┌ Info: Found initial step size
└   ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se     mcse       ess   r_hat
  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
  ──────────  ───────  ───────  ───────  ───────  ───────
        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true
```
"""
function DynamicPPL.predict(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    include_all=false,
)
    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
    prototypical_varinfo = DynamicPPL.VarInfo(model)

    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        varinfo = deepcopy(prototypical_varinfo)
        DynamicPPL.setval_and_resample!(
            varinfo, parameter_only_chain, sample_idx, chain_idx
        )
        model(rng, varinfo, DynamicPPL.SampleFromPrior())

        vals = DynamicPPL.values_as_in_model(model, varinfo)
Member: This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in …

Member (author): …

Member: Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕

Member (author): I saw your issue on … We would need to make a minor release of …

Member: But isn't this the purpose of this PR? To move the … Whether we're using …

Member (author): Ideally, I would want this PR to do a proper implementation of … What I was trying to say is that, with …

Member: Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.
        varname_vals = mapreduce(
            collect,
            vcat,
            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )

        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
    end

    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end

function _predictive_samples_to_arrays(predictive_samples)
    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

    sample_dicts = map(predictive_samples) do sample
        varname_value_pairs = sample.varname_and_values
        varnames = map(first, varname_value_pairs)
        values = map(last, varname_value_pairs)
        for varname in varnames
            push!(variable_names_set, varname)
        end

        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
    end

    variable_names = collect(variable_names_set)
    variable_values = [
        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
        key in variable_names
    ]

    return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
    variable_names_symbols = map(Symbol, variable_names)

    internal_parameters = [:lp]
    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

    parameter_names = [variable_names_symbols; internal_parameters]
    parameter_values = hcat(variable_values, log_probabilities)
    parameter_values = MCMCChains.concretize(parameter_values)

    return MCMCChains.Chains(
        parameter_values, parameter_names, (internals=internal_parameters,)
    )
end

"""
    generated_quantities(model::Model, chain::MCMCChains.Chains)
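The `_predictive_samples_to_arrays` helper above collects the union of variable names seen across all predictive samples and pads absent entries with `missing`. A minimal plain-Julia sketch of that padding step (the names and values here are made up for illustration, mirroring the `get(..., missing)` call above):

```julia
# Each predictive sample is a lookup table of variable name => value.
# Build a samples × names matrix, padding with `missing` where a sample
# did not produce a given variable.
samples = [Dict(:a => 1.0, :b => 2.0), Dict(:a => 3.0)]
names_union = [:a, :b]  # union of variable names, in first-seen order

values_matrix = [get(samples[i], k, missing) for i in eachindex(samples), k in names_union]
# values_matrix[2, 2] === missing, since the second sample has no :b.
```

This is why the resulting element type is `Union{Missing, Float64}` until `MCMCChains.concretize` narrows it where possible.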
@@ -1203,6 +1203,22 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
    end
end

"""
    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each
sample in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed
by the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
    # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
    # TODO: add other methods for different types of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
    return predict(Random.default_rng(), model, chain; include_all)
end

"""
    generated_quantities(model::Model, parameters::NamedTuple)
    generated_quantities(model::Model, values, keys)
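The `rng`-less method above is a standard Julia forwarding pattern: the method that takes an `rng` does the work, and the convenience method forwards to it with `Random.default_rng()`. The same pattern in isolation (the function name `draw` is hypothetical):

```julia
using Random

# The rng-taking method does the work; the convenience method forwards to it.
draw(rng::Random.AbstractRNG, n::Int; scale=1.0) = scale .* randn(rng, n)
draw(n::Int; scale=1.0) = draw(Random.default_rng(), n; scale)

# Passing an explicit rng makes results reproducible:
a = draw(MersenneTwister(42), 3)
b = draw(MersenneTwister(42), 3)
# a == b holds, since both calls used identically seeded rngs.
```

This is also why the tests below can check that `predict` respects a user-supplied `rng`.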
@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

@@ -32,6 +33,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14, 0.15"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
@@ -7,3 +7,170 @@
    @test size(chain_generated) == (1000, 1)
    @test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
    DynamicPPL.Random.seed!(100)

    @model function linear_reg(x, y, σ=0.1)
        β ~ Normal(0, 1)

        for i in eachindex(y)
            y[i] ~ Normal(β * x[i], σ)
        end
    end

    @model function linear_reg_vec(x, y, σ=0.1)
        β ~ Normal(0, 1)
        return y ~ MvNormal(β .* x, σ^2 * I)
    end

    f(x) = 2 * x + 0.1 * randn()

    Δ = 0.1
    xs_train = 0:Δ:10
    ys_train = f.(xs_train)
    xs_test = [10 + Δ, 10 + 2 * Δ]
    ys_test = f.(xs_test)

    # Infer
    m_lin_reg = linear_reg(xs_train, ys_train)
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg),
        AdvancedHMC.NUTS(0.65),
        1000;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )

    # Predict on the two last indices
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

    # A test like this depends on the variance of the posterior;
    # it only makes sense if the posterior variance is about 0.002.
    @test sum(abs2, ys_test - ys_pred) ≤ 0.1

    # Ensure that `rng` is respected
    predictions1 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    predictions2 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    @test all(Array(predictions1) .== Array(predictions2))

    # Predict on the two last indices for the vectorized model
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

    # Multiple chains
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
        AdvancedHMC.NUTS(0.65),
        MCMCThreads(),
        1000,
        2;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    @test size(chain_lin_reg, 3) == size(predictions, 3)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
    end

    # Predict on the two last indices for the vectorized model
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
    end

    # https://github.com/TuringLang/Turing.jl/issues/1352
    @model function simple_linear1(x, y)
        intercept ~ Normal(0, 1)
        coef ~ MvNormal(zeros(2), I)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear2(x, y)
        intercept ~ Normal(0, 1)
        coef ~ filldist(Normal(0, 1), 2)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear3(x, y)
        intercept ~ Normal(0, 1)
        coef = Vector(undef, 2)
        for i in axes(coef, 1)
            coef[i] ~ Normal(0, 1)
        end
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear4(x, y)
        intercept ~ Normal(0, 1)
        coef1 ~ Normal(0, 1)
        coef2 ~ Normal(0, 1)
        coef = [coef1, coef2]
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    x = randn(2, 100)
    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

    param_names = Dict(
        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear4 => [:intercept, :coef1, :coef2, :error],
    )
    @testset "$model" for model in
                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
        m = model(x, y)
        chain = sample(
            DynamicPPL.LogDensityFunction(m),
            AdvancedHMC.NUTS(0.65),
            400;
            initial_params=rand(4),
            chain_type=MCMCChains.Chains,
            param_names=param_names[model],
            discard_initial=100,
            n_adapt=100,
        )
        chain_predict = DynamicPPL.predict(model(x, missing), chain)
        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
        @test mean(abs2, mean_prediction - y) ≤ 1e-3
    end
end
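These tests, like the docstring example, mark the response to be predicted by passing `missing` (or a container of `missing`s) for `y`, so the model treats `y[i] ~ ...` as a draw from the predictive distribution rather than an observation. A tiny stand-in for that convention in plain Julia (no DynamicPPL; the helper name is hypothetical):

```julia
# Entries that are `missing` get a fresh draw; observed entries are kept as data.
simulate_missing(y, draw) = [ismissing(v) ? draw() : v for v in y]

y = [1.0, missing, 2.5]
filled = simulate_missing(y, () -> 0.0)
# filled == [1.0, 0.0, 2.5]
```

In the real `predict`, the "draw" comes from executing the model with `SampleFromPrior()` after fixing the chain's parameters.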
Member: Same here: no need to use `AdvancedHMC` (or any of the other packages); just construct the `Chains` by hand. This also doesn't actually show that you need to import `MCMCChains` for this to work, which might be a good idea.
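The reviewer's suggestion can be sketched as follows, assuming `MCMCChains` is available; the shapes and parameter names are made up, and the constructor call mirrors the one already used in `_predictive_samples_to_chains` above:

```julia
using MCMCChains

# Build a Chains directly from a raw array instead of running a sampler:
# 100 iterations × 2 columns × 1 chain, with :lp marked as an internal.
vals = randn(100, 2, 1)
chain = Chains(vals, [:β, :lp], (internals=[:lp],))
# size(chain) == (100, 2, 1)
```

Constructing chains by hand like this keeps the test suite free of a sampler dependency and makes the test deterministic up to the seeded `randn`.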