diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index 36f11b914..8074095f8 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -18,7 +18,7 @@ concurrency: jobs: enzyme: - runs-on: ubuntu-latest + runs-on: macos-latest steps: - uses: actions/checkout@v5 @@ -27,9 +27,19 @@ jobs: version: "1.11" - uses: julia-actions/cache@v2 + id: julia-cache - name: Run AD with Enzyme on demo models working-directory: test/integration/enzyme run: | julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()' julia --project=. --color=yes main.jl + + - name: Save Julia depot cache on cancel or failure + id: julia-cache-save + if: cancelled() || failure() + uses: actions/cache/save@v4 + with: + path: | + ${{ steps.julia-cache.outputs.cache-paths }} + key: ${{ steps.julia-cache.outputs.cache-key }} diff --git a/HISTORY.md b/HISTORY.md index 24c0df3d0..777f3f32c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,54 @@ # DynamicPPL Changelog +## 0.39.0 + +### Breaking changes + +#### Fast Log Density Functions + +This version reimplements `LogDensityFunction`, providing performance improvements on the order of 2–10× for both model evaluation and automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + +As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. +In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. +If you were previously relying on this behaviour, you will need to store a VarInfo separately. + +Along with this change, DynamicPPL now exposes the `fast_evaluate!!` method, which allows you to hook into this 'fast evaluation' pipeline directly. +Please see the documentation for details. + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. +This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`. + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users. + +Leaf contexts require no changes, apart from the removal of the `NodeTrait` function. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly; please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### SimpleVarInfo + +`SimpleVarInfo` has been removed. +Its main purpose was for evaluating models rapidly. +However, `fast_evaluate!!` provides a cleaner way of doing this. +In particular, if you want to evaluate a model at a given set of parameters, you can do: + +```julia +retval, vi = DynamicPPL.fast_evaluate!!(rng, model, InitFromParams(params), accs) +``` + +#### Miscellaneous + +Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
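+For example, a call that previously passed values and keys separately can be migrated as follows (a sketch; the model `m` and its random variable `x` are assumed, and the old call shown in the comment is illustrative):
+
+```julia
+# Before (0.38): retval = returned(m, [1.0], [@varname(x)])
+# Now: pass a single Dict mapping VarNames to values
+retval = returned(m, Dict(@varname(x) => 1.0))
+```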
+ +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/Project.toml b/Project.toml index 23f5eec5b..1b5e52492 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.9" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index c154c5ca5..523889a7a 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -24,7 +24,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" Chairmarks = "1.3.1" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "1" JSON = "1.3.0" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 3af6573cf..ba3439986 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -68,9 +68,7 @@ function run(; to_json=false) false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), @@ -98,12 +96,15 @@ function run(; to_json=false) }[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name" + @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try results = benchmark(model, varinfo_choice, adbackend, islinked) + @info " t(eval) = $(results.primal_time)" + @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), (results.grad_time / results.primal_time) catch e + @info "benchmark errored: $e" missing, missing end push!( @@ -155,18 +156,33 @@ function combine(head_filename::String, base_filename::String) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) ) results_table = Tuple{ - String,Int,String,String,Bool,String,String,String,String,String,String + String, + Int, + String, + String, + Bool, + String, + String, + String, + String, + String, + String, + String, + String, + String, }[] + sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ EmptyCells(5), MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), + MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"], + [colnames[1:5]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -183,6 +199,10 @@ function 
combine(head_filename::String, base_filename::String) # Finally that lets us do this division safely speedup_eval = base_eval / head_eval speedup_grad = base_grad / head_grad + # As well as this multiplication, which is t(grad) / t(ref) + head_grad_vs_ref = head_grad * head_eval + base_grad_vs_ref = base_grad * base_eval + speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref push!( results_table, ( @@ -197,6 +217,9 @@ function combine(head_filename::String, base_filename::String) sprint_float(base_grad), sprint_float(head_grad), sprint_float(speedup_grad), + sprint_float(base_grad_vs_ref), + sprint_float(head_grad_vs_ref), + sprint_float(speedup_grad_vs_ref), ), ) end @@ -212,7 +235,10 @@ function combine(head_filename::String, base_filename::String) backend=:text, fit_table_in_display_horizontally=false, fit_table_in_display_vertically=false, - table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true), + table_format=TextTableFormat(; + horizontal_line_at_merged_column_labels=true, + horizontal_lines_at_data_rows=collect(3:3:length(results_table)), + ), ) println("```") end diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 0dc7ece6e..00d2e071b 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, VarName using DynamicPPL: DynamicPPL using DynamicPPL.TestUtils.AD: run_ad, NoTest using ADTypes: ADTypes @@ -60,8 +60,6 @@ and AD backend. Available varinfo choices: • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). @@ -76,12 +74,6 @@ function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::B DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed DynamicPPL.typed_varinfo(rng, model) - elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) - elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) elseif varinfo_choice == :typed_vector DynamicPPL.typed_vector_varinfo(rng, model) elseif varinfo_choice == :untyped_vector diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl index 2c881aa95..76d4b2e93 100644 --- a/benchmarks/src/Models.jl +++ b/benchmarks/src/Models.jl @@ -2,7 +2,7 @@ Models for benchmarking Turing.jl. Each model returns a NamedTuple of all the random variables in the model that are not -observed (this is used for constructing SimpleVarInfos). +observed. 
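+For example, a hypothetical model with latent variables `m` and `s` and observed data `x` would `return (; m, s)`.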
""" module Models diff --git a/docs/Project.toml b/docs/Project.toml index 03a3ff0a0..10a4a5c8a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -21,7 +21,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..a3b9e2fdc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished using: + +```@docs +OnlyAccsVarInfo +fast_evaluate!! +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -170,6 +177,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -346,19 +359,6 @@ set_transformed!! Base.empty! ``` -#### `SimpleVarInfo` - -```@docs -SimpleVarInfo -``` - -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -463,15 +463,48 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context. + +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks. 
+They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. @@ -491,10 +524,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. +In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL._get_range_and_linked), args... +) = nothing end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..e74f0b8a9 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -140,6 +140,43 @@ function AbstractMCMC.to_samples( end end +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DynamicPPL.Model, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + save_state=false, + stats=missing, + sort_chain=false, + discard_initial=0, + thinning=1, + kwargs..., +) + bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1)) + + # Add additional MCMC-specific info + info = bare_chain.info + if save_state + info = merge(info, (model=model, sampler=spl, samplerstate=state)) + end + if !ismissing(stats) + info = merge(info, (start_time=stats.start, stop_time=stats.stop)) + end + + # Reconstruct the chain with the extra information + # Yeah, this is quite ugly. Blame MCMCChains. + chain = MCMCChains.Chains( + bare_chain.value.data, + names(bare_chain), + bare_chain.name_map; + info=info, + start=discard_initial + 1, + thin=thinning, + ) + return sort_chain ? sort(chain) : chain +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. 
-struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... ) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -5,5 +5,8 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL._get_range_and_linked),Vararg +} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..e002c5f42 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,7 +46,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - SimpleVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -84,30 +83,39 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, getargnames, extract_priors, values_as_in_model, - # LogDensityFunction + # LogDensityFunction and fasteval LogDensityFunction, - # Contexts + fast_evaluate!!, + OnlyAccsVarInfo, + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, InitFromParams, + init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, @@ -165,7 +173,7 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). +See also: [`VarInfo`](@ref). 
""" abstract type AbstractVarInfo <: AbstractModelTrace end @@ -187,13 +195,14 @@ include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") -include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("logdensityfunction.jl") +include("fasteval.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -201,7 +210,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..14528522b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -502,52 +502,6 @@ If no `Type` is provided, return values as stored in `varinfo`. # Examples -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - `VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest @@ -828,8 +782,8 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation + # Note that VarInfo has a dedicated implementation so this is only a generic + # fallback (previously used for SimpleVarInfo) model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, t) @@ -890,8 +844,8 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation + # Note that VarInfo has a dedicated implementation so this is only a generic + # fallback (previously used for SimpleVarInfo) model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, NoTransformation()) @@ -946,47 +900,6 @@ This will be called prior to `model` evaluation, allowing one to perform a singl basis as is done with [`DynamicTransformation`](@ref). See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref). 
- -# Examples -```julia-repl -julia> using DynamicPPL, Distributions, Bijectors - -julia> @model demo() = x ~ Normal() -demo (generic function with 2 methods) - -julia> # By subtyping `Transform`, we inherit the `(inv)link!!`. - struct MyBijector <: Bijectors.Transform end - -julia> # Define some dummy `inverse` which will be used in the `link!!` call. - Bijectors.inverse(f::MyBijector) = identity - -julia> # We need to define `with_logabsdet_jacobian` for `MyBijector` - # (`identity` already has `with_logabsdet_jacobian` defined) - function Bijectors.with_logabsdet_jacobian(::MyBijector, x) - # Just using a large number of the logabsdet-jacobian term - # for demonstration purposes. - return (x, 1000) - end - -julia> # Change the `default_transformation` for our model to be a - # `StaticTransformation` using `MyBijector`. - function DynamicPPL.default_transformation(::Model{typeof(demo)}) - return DynamicPPL.StaticTransformation(MyBijector()) - end - -julia> model = demo(); - -julia> vi = SimpleVarInfo(x=1.0) -SimpleVarInfo((x = 1.0,), 0.0) - -julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` - vi_linked = link!!(vi, model) -Transformed SimpleVarInfo((x = 1.0,), 0.0) - -julia> # Now performs a single `invlink!!` before model evaluation. - logjoint(model, vi_linked) --1001.4189385332047 -``` """ function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model) return maybe_invlink_before_eval!!(transformation(vi), vi, model) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..892423822 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -133,3 +133,60 @@ function ParamsWithStats( end return ParamsWithStats(params, stats) end + +""" + ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided +`param_vector`. + +This method is intended to replace the old method of obtaining parameters and statistics +via `unflatten` plus re-evaluation. It is faster for two reasons: + +1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as + otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent + MCMC iterations). +2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`. 
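+
+A sketch of typical usage inside a sampler, where `θ` is an assumed parameter vector taken from the current MCMC state and `ldf` is an existing `FastLDF`:
+
+```julia
+ps = ParamsWithStats(θ, ldf, (step_size=0.1,))
+```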
+""" +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + strategy = InitFromParams( + VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + nothing, + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + _, vi = DynamicPPL.Experimental.fast_evaluate!!( + ldf.model, strategy, AccumulatorTuple(accs) + ) + params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + lp=DynamicPPL.getlogjoint(vi), + ), + ) + end + return ParamsWithStats(params, stats) +end diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. 
""" -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. ```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! 
not implemented for context of type $(typeof(context))") end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. """ struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is in contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...)
-end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. """ -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precedence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) end @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is in contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants.
""" function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. 
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -14,14 +14,60 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. 
You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +point about accumulator types only applies when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion +""" +get_param_eltype(::AbstractInitStrategy) = Any + """ InitFromPrior() Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return rand(rng, dist), typed_identity end """ @@ -69,43 +115,61 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return x, typed_identity end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple.
+Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +## Custom parameter storage types -!!! note - The values in `params` must be provided in the space of the untransformed - distribution. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement + +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return InitFromParams(to_varname_dict(params), fallback) - end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) +InitFromParams(params) = InitFromParams(params, InitFromPrior()) + +function init( + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -119,13 +183,89 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? 
- x + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = if range_and_linked.is_linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return (@view vr.vect[range_and_linked.range]), transform +end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( @@ -150,15 +290,13 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo already had a value for this variable, we will # keep the same linked status as in the original VarInfo.
If not, we @@ -166,17 +304,49 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. + y, -inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. 
See also: [`to_submodel`](@ref)
"""
-struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext
+struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext
    vn_prefix::Tvn
    context::C
end
@@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym}
 end
 PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}())
 
-NodeTrait(::PrefixContext) = IsParent()
 childcontext(context::PrefixContext) = context.context
 function setchildcontext(ctx::PrefixContext, child::AbstractContext)
     return PrefixContext(ctx.vn_prefix, child)
 end
@@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`.
 """
 function prefix(ctx::PrefixContext, vn::VarName)
     return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix)
 end
-function prefix(ctx::AbstractContext, vn::VarName)
-    return prefix(NodeTrait(ctx), ctx, vn)
-end
-prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
-function prefix(::IsParent, ctx::AbstractContext, vn::VarName)
+prefix(::AbstractContext, vn::VarName) = vn
+function prefix(ctx::AbstractParentContext, vn::VarName)
     return prefix(childcontext(ctx), vn)
 end
@@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName)
     )
     return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes
 end
-function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName)
-    return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn)
-end
-prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx)
-function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName)
+prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx)
+function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName)
     vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn)
     return vn, setchildcontext(ctx, new_ctx)
 end
diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl
index 5153f7857..0914d7a79 100644
--- a/src/contexts/transformation.jl
+++ b/src/contexts/transformation.jl
@@ -7,10 +7,9 @@ constrained space if `isinverse` or unconstrained if `!isinverse`.
 
 Note that some `AbstractVarInfo` types, most notably `VarInfo`, override the
 `DynamicTransformationContext` methods with more efficient implementations.
 `DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
-how to do the transformation, used by e.g. `SimpleVarInfo`.
+how to do the transformation.
""" struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/fasteval.jl b/src/fasteval.jl new file mode 100644 index 000000000..91c4edec0 --- /dev/null +++ b/src/fasteval.jl @@ -0,0 +1,401 @@ +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + AbstractInitStrategy, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI +using Random: Random + +""" + DynamicPPL.fast_evaluate!!( + [rng::Random.AbstractRNG,] + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, + ) + +Evaluate a model using parameters obtained via `strategy`, and only computing the results in +the provided accumulators. + +It is assumed that the accumulators passed in have been initialised to appropriate values, +as this function will not reset them. The default constructors for each accumulator will do +this for you correctly. + +Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` +argument may be mutated (depending on how the accumulators are implemented); hence the `!!` +in the function name. +""" +@inline function fast_evaluate!!( + # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads + # to extra allocations (even for trivial models) and much slower runtime. + rng::Random.AbstractRNG, + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, +) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + return DynamicPPL._evaluate!!(model, vi) +end +@inline function fast_evaluate!!( + model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple +) + # This `@inline` is also mandatory for performance + return fast_evaluate!!(Random.default_rng(), model, strategy, accs) +end + +""" + DynamicPPL.LogDensityFunction( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: + +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. 
With `unflatten` + `evaluate!!` under `DefaultContext`: this stores a vector of parameters
+   inside a VarInfo's metadata, then reads parameter values from the VarInfo during
+   evaluation.
+
+2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and
+   stores them inside a VarInfo's metadata.
+
+In general, both of these approaches work fine, but the fact that they modify the VarInfo's
+metadata can often be quite wasteful. In particular, it is very common that the only outputs
+we care about from model evaluation are those which are stored in accumulators, such as log
+probability densities, or `ValuesAsInModel`.
+
+To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
+accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
+model evaluation.
+
+Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
+it, it is mandatory that parameters are provided from outside the VarInfo, namely via
+`InitContext`.
+
+The main problem that we face is that it is not possible to directly implement
+`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`.
+In particular, it is not clear:
+
+  - which parts of the vector correspond to which random variables, and
+  - whether the variables are linked or unlinked.
+
+Traditionally, this problem has been solved by `unflatten`, because that function would
+place values into the VarInfo's metadata alongside the information about ranges and linking.
+That way, when we evaluate with `DefaultContext`, we can read this information out again.
+However, we want to avoid using the metadata entirely. Thus, here, we _extract this
+information from the VarInfo_ a single time when constructing a `LogDensityFunction` object.
+Inside the LogDensityFunction, we store a mapping from VarNames to ranges in the parameter
+vector, along with link status.
+
+For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
+other VarNames, this is stored in a Dict. The internal data structure used to represent this
+could almost certainly be optimised further. See, e.g., the discussion in
+https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
+
+When evaluating the model, this allows us to combine the parameter vector with those ranges
+to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read parameter
+values from the vector.
+
+Note that this assumes that the ranges and link status are static throughout the lifetime of
+the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle
+models which have a variable number of parameters, or models which may visit random
+variables in different orders depending on stochastic control flow. **Indeed, silent errors
+may occur with such models.** This is a general limitation of vectorised parameters: the
+original `unflatten` + `evaluate!!` approach also fails with such models.
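+
+## Example
+
+The following is a minimal usage sketch, adapted from the example in the removed
+`src/logdensityfunction.jl` docstring. The `demo` model and the values shown in the
+comments are illustrative only, and assume that Distributions and ForwardDiff are
+available in the environment:
+
+```julia
+using Distributions, DynamicPPL, LogDensityProblems
+import ADTypes, ForwardDiff
+
+@model function demo(x)
+    m ~ Normal()
+    x ~ Normal(m, 1)
+end
+model = demo(1.0)
+
+# Calculate the log joint density at m = 0.0 via the LogDensityProblems.jl interface.
+f = LogDensityFunction(model)
+LogDensityProblems.logdensity(f, [0.0])  # ≈ logpdf(Normal(), 0.0) + logpdf(Normal(0.0, 1.0), 1.0)
+
+# Providing an AD type additionally enables gradient calculation.
+f_ad = LogDensityFunction(model; adtype=ADTypes.AutoForwardDiff())
+LogDensityProblems.logdensity_and_gradient(f_ad, [0.0])  # ≈ (-2.3378770664093453, [1.0])
+```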
+""" +struct LogDensityFunction{ + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} + model::M + adtype::AD + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + _dim::Int + + function LogDensityFunction( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + DI.prepare_gradient( + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) + end + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim + ) + end +end + +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) +end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) +end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) + +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::LogDensityAt)(params::AbstractVector{<:Real}) + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + ) + accs = fast_ldf_accs(f.getlogdensity) + _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs) + return f.getlogdensity(vi) +end + +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + )( + params + ) +end + +function LogDensityProblems.logdensity_and_gradient( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ), + ldf._adprep, + ldf.adtype, + params, + ) +end + +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} +) where {M} + return LogDensityProblems.LogDensityOrder{1}() +end +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim +end + +""" + tweak_adtype( + adtype::ADTypes.AbstractADType, + model::Model, + 
varinfo::AbstractVarInfo, + ) + +Return an 'optimised' form of the adtype. This is useful for doing +backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating +the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). +The model is passed as a parameter in case the optimisation depends on the +model. + +By default, this just returns the input unchanged. +""" +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype + +###################################################### +# Helper functions to extract ranges and link status # +###################################################### + +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +""" +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl deleted file mode 100644 index 7c7438c9f..000000000 --- a/src/logdensityfunction.jl +++ /dev/null @@ -1,377 +0,0 @@ -using AbstractMCMC: AbstractModel -import DifferentiationInterface as DI - -""" - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. - -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. 
-""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) - -!!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. 
the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` -""" -struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" - model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" - adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} - - function LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - if adtype === nothing - prep = nothing - else - # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end - end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep - ) - end -end - -""" - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) - -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. -""" -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end -end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. 
- """ - return error(msg) -end - -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) -end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, - ) - -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo -) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) -end - -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V - ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. -""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) -end - -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) -end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? 
-LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils - -""" - tweak_adtype( - adtype::ADTypes.AbstractADType, - model::Model, - varinfo::AbstractVarInfo, - ) - -Return an 'optimised' form of the adtype. This is useful for doing -backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating -the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). -The model is passed as a parameter in case the optimisation depends on the -model. - -By default, this just returns the input unchanged. -""" -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype - -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. -""" -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false - -""" - getmodel(f) - -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model - -""" - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) -end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. -""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/src/model.jl b/src/model.jl index edb042ba9..2ba0c6cd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... 
+ ) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1010,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) + +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. +""" +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) + return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) +end +get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) +function get_param_eltype(::AbstractVarInfo, ctx::InitContext) + return get_param_eltype(ctx.strategy) +end + """ getargnames(model::Model) @@ -1034,8 +1062,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) - return values_as(x, T) + # TODO(penelopeysm): This can be done with an accumulator instead. For + # T = Dict, ValuesAsInModelAcc can already do it. For T = NamedTuple we + # would just need a similar accumulator that collects into a NamedTuple + # rather than a Dict. + return values_as(VarInfo(rng, model), T) end # Default RNG and type @@ -1107,11 +1138,6 @@ function predict end Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. - returned(model::Model, values, keys) - -Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. -This method is deprecated; use the NamedTuple or AbstractDict version instead. - # Example ```jldoctest julia> using DynamicPPL, Distributions @@ -1132,15 +1158,115 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) - # Note: we can't use `fix(model, parameters)` because - # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 - # Use `nothing` as the fallback to ensure that any missing parameters cause an error - ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) - new_model = setleafcontext(model, ctx) - # We can't use new_model() because that overwrites it with an InitContext of its own. - return first(evaluate!!(new_model, vi)) + accs = AccumulatorTuple() + retval, _ = DynamicPPL.fast_evaluate!!(model, InitFromParams(parameters, nothing), accs) + return retval +end + +""" + logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) + +Return the log joint probability of variables `θ` for the probabilistic `model`. + +See [`logprior`](@ref) and [`loglikelihood`](@ref). 
+ +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logjoint(demo([1.0]), (m = 100.0, )) +-9902.33787706641 + +julia> # Using a `OrderedDict`. + logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-9902.33787706641 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) +-9902.33787706641 +``` +""" +function logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) + _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs) + return getlogjoint(vi) +end + +""" + logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) + +Return the log prior probability of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logprior(demo([1.0]), (m = 100.0, )) +-5000.918938533205 + +julia> # Using a `OrderedDict`. + logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-5000.918938533205 + +julia> # Truth. + logpdf(Normal(), 100.0) +-5000.918938533205 +``` +""" +function logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(),)) + _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs) + return getlogprior(vi) +end + +""" + loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) + +Return the log likelihood of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`logprior`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + loglikelihood(demo([1.0]), (m = 100.0, )) +-4901.418938533205 + +julia> # Using a `OrderedDict`. + loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-4901.418938533205 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +function Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogLikelihoodAccumulator(),)) + _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs) + return getloglikelihood(vi) end -Base.@deprecate returned(model::Model, values, keys) returned( - model, NamedTuple{keys}(values) -) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl new file mode 100644 index 000000000..940f23124 --- /dev/null +++ b/src/onlyaccs.jl @@ -0,0 +1,42 @@ +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for +`InitContext`. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this with a different leaf context such as `DefaultContext` will result in errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. 
+This is also why it only works with `InitContext`: in this case, the parameters used for +evaluation are supplied by the context instead of the metadata. +""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# Minimal AbstractVarInfo interface +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +# Ideally, we'd define this together with InitContext, but alas that file comes way before +# this one, and sorting out the include order is a pain. +function tilde_assume!!( + ctx::InitContext, + dist::Distribution, + vn::VarName, + vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, +) + # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can + # cut out a lot of the code above. + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 434480be6..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,655 +0,0 @@ -""" - $(TYPEDEF) - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. - -# Fields -$(FIELDS) - -# Notes -The major differences between this and `NTVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either - a) no indexing is used in tilde-statements, or - b) the values have been specified with the correct shapes. - -# Examples -## General usage -```jldoctest simplevarinfo-general; setup=:(using Distributions) -julia> using StableRNGs - -julia> @model function demo() - m ~ Normal() - x = Vector{Float64}(undef, 2) - for i in eachindex(x) - x[i] ~ Normal() - end - return x - end -demo (generic function with 2 methods) - -julia> m = demo(); - -julia> rng = StableRNG(42); - -julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variables which are using "containers", e.g. `Array`. - # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); - -julia> # (✓) Vroom, vroom! FAST!!! - vi[@varname(x[1])] -0.4471218424633827 - -julia> # We can also access arbitrary varnames pointing to `x`, e.g. - vi[@varname(x)] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> vi[@varname(x[1:2])] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); -ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` -[...] - -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); - -julia> # (✓) Sort of fast, but only possible at runtime. 
- vi[@varname(x[1])] --1.019202452456547 - -julia> # In addtion, we can only access varnames as they appear in the model! - vi[@varname(x)] -ERROR: x was not found in the dictionary provided -[...] - -julia> vi[@varname(x[1:2])] -ERROR: x[1:2] was not found in the dictionary provided -[...] -``` - -_Technically_, it's possible to use any implementation of `AbstractDict` in place of -`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening -of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is -the preferred implementation of `AbstractDict` to use here. - -You can also sample in _transformed_ space: - -```jldoctest simplevarinfo-general -julia> @model demo_constrained() = x ~ Exponential() -demo_constrained (generic function with 2 methods) - -julia> m = demo_constrained(); - -julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); - -julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ -1.8632965762164932 - -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ --0.21080155351918753 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true - -julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.6225185067787314 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true -``` - -Evaluation in transformed space of course also works: - -```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) Positive probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --1.3678794411714423 - -julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) No probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --Inf -``` - -## Indexing -Using `NamedTuple` as underlying storage. - -```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); - -julia> svi_nt[@varname(m)] -(a = [1.0],) - -julia> svi_nt[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_nt[@varname(m.a[1])] -1.0 - -julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` -[...] -``` - -Using `OrderedDict` as underlying storage. 
-```jldoctest -julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); - -julia> svi_dict[@varname(m)] -(a = [1.0],) - -julia> svi_dict[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_dict[@varname(m.a[1])] -1.0 - -julia> svi_dict[@varname(m.a[2])] -ERROR: m.a[2] was not found in the dictionary provided -[...] - -julia> svi_dict[@varname(m.b)] -ERROR: m.b was not found in the dictionary provided -[...] -``` -""" -struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: - AbstractVarInfo - "underlying representation of the realization represented" - values::NT - "tuple of accumulators for things like log prior and log likelihood" - accs::Accs - "represents whether it assumes variables to be transformed" - transformation::C -end - -function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) - return vi1.values == vi2.values && - vi1.accs == vi2.accs && - vi1.transformation == vi2.transformation -end - -transformation(vi::SimpleVarInfo) = vi.transformation - -function SimpleVarInfo(values, accs) - return SimpleVarInfo(values, accs, NoTransformation()) -end -function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, default_accumulators(T)) -end -function SimpleVarInfo(values) - return SimpleVarInfo{LogProbType}(values) -end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) - return if isempty(values) - # Can't infer from values, so we just use default. - SimpleVarInfo{LogProbType}(values) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) - end -end - -# Using `kwargs` to specify the values. -function SimpleVarInfo{T}(; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs)) -end -function SimpleVarInfo(; kwargs...) - return SimpleVarInfo(NamedTuple(kwargs)) -end - -# Constructor from `Model`. -function SimpleVarInfo{T}( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) where {T<:Real} - return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) -end -function SimpleVarInfo{T}( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) -end -# Constructors without type param -function SimpleVarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return SimpleVarInfo{LogProbType}(rng, model, init_strategy) -end -function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) -end - -# Constructor from `VarInfo`. 
-function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} - values = values_as(vi, D) - return SimpleVarInfo(values, copy(getaccs(vi))) -end -function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} - values = values_as(vi, D) - accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) - return SimpleVarInfo(values, accs) -end - -function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(init!!(model, varinfo)) -end - -function typed_simple_varinfo(model::Model) - varinfo = SimpleVarInfo{Float64}() - return last(init!!(model, varinfo)) -end - -function unflatten(svi::SimpleVarInfo, x::AbstractVector) - vals = unflatten(svi.values, x) - # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is - # required but undesireable. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) - ) - return SimpleVarInfo(vals, accs, svi.transformation) -end - -function BangBang.empty!!(vi::SimpleVarInfo) - return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) -end -Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) - -getaccs(vi::SimpleVarInfo) = vi.accs -setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs - -""" - keys(vi::SimpleVarInfo) - -Return an iterator of keys present in `vi`. -""" -Base.keys(vi::SimpleVarInfo) = keys(vi.values) -Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) - -function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transformation isa NoTransformation) - print(io, "Transformed ") - end - - return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") -end - -function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn)) -end -function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) - -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(getindex, vi), vns) -end -# HACK: Needed to disambiguate. -Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) - -Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) - -getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `AbstractDict` -function getindex_internal( - vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName -) - return getvalue(vi.values, vn) -end - -Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) - -function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) - # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Accessors.@set vi.values = set!!(vi.values, vn, val) -end - -# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with -# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. 
-function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) - for (vn, val) in zip(vns, vals) - vi = BangBang.setindex!!(vi, val, vn) - end - return vi -end - -function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) - # For dictlike objects, we treat the entire `vn` as a _key_ to set. - dict = values_as(vi) - # Attempt to split into `parent` and `child` optic. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(dict, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - dict_new = if !issuccess - # Split doesn't exist ⟹ we're working with a new key. - BangBang.setindex!!(dict, val, vn) - else - # Split exists ⟹ trying to set an existing key. - vn_key = VarName{getsym(vn)}(keyoptic) - BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) - end - return Accessors.@set vi.values = dict_new -end - -# `NamedTuple` -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) -end -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = set!!(vi.values, vn, value) -end - -# `AbstractDict` -function BangBang.push!!( - vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution -) - vi.values[vn] = value - return vi -end - -function BangBang.push!!( - vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution -) - # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For - # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. - # Hence we need to call update!! here, which has the same semantics as push!! does for - # SimpleVarInfo. - return Accessors.@set vi.values = setindex!!(vi.values, value, vn) -end - -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - -# Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V - -# `subset` -function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return SimpleVarInfo( - _subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation - ) -end - -function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} - vns_present = collect(keys(x)) - vns_found = filter( - vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present - ) - C = ConstructionBase.constructorof(typeof(x)) - if isempty(vns_found) - return C() - else - return C(vn => x[vn] for vn in vns_found) - end -end - -function _subset(x::NamedTuple, vns) - # NOTE: Here we can only handle `vns` that contain `identity` as optic. - if any(Base.Fix1(!==, identity) ∘ getoptic, vns) - throw( - ArgumentError( - "Cannot subset `NamedTuple` with non-`identity` `VarName`. 
" * - "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", - ), - ) - end - - syms = map(getsym, vns) - x_syms = filter(Base.Fix2(in, syms), keys(x)) - return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms))) -end - -_subset(x::VarNamedVector, vns) = subset(x, vns) - -# `merge` -function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) - values = merge(varinfo_left.values, varinfo_right.values) - accs = map(copy, getaccs(varinfo_right)) - transformation = merge_transformations( - varinfo_left.transformation, varinfo_right.transformation - ) - return SimpleVarInfo(values, accs, transformation) -end - -function set_transformed!!(vi::SimpleVarInfo, trans) - return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation()) -end -function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation) - return Accessors.@set vi.transformation = transformation -end -function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans) -end -function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) - # We keep this method around just to obey the AbstractVarInfo interface. - # However, note that this would only be a valid operation if it would be a - # no-op, which we check here. - if trans != is_transformed(vi) - error( - "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", - ) - end - return vi -end - -is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi) -function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) - return is_transformed(vi.varinfo, vn) -end -is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo) - -values_as(vi::SimpleVarInfo) = vi.values -values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo, ::Type{Vector}) - isempty(vi) && return Any[] - return mapreduce(tovec, vcat, values(vi.values)) -end -function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) -end -function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) - return NamedTuple((Symbol(k), v) for (k, v) in vi.values) -end -function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} - return values_as(vi.values, T) -end - -""" - logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log joint probability of variables `θ` for the probabilistic `model`. - -See [`logprior`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logjoint(demo([1.0]), (m = 100.0, )) --9902.33787706641 - -julia> # Using a `OrderedDict`. - logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --9902.33787706641 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) --9902.33787706641 -``` -""" -logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logjoint(model, SimpleVarInfo(θ)) - -""" - logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log prior probability of variables `θ` for the probabilistic `model`. 
- -See also [`logjoint`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logprior(demo([1.0]), (m = 100.0, )) --5000.918938533205 - -julia> # Using a `OrderedDict`. - logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --5000.918938533205 - -julia> # Truth. - logpdf(Normal(), 100.0) --5000.918938533205 -``` -""" -logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logprior(model, SimpleVarInfo(θ)) - -""" - loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log likelihood of variables `θ` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`logprior`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - loglikelihood(demo([1.0]), (m = 100.0, )) --4901.418938533205 - -julia> # Using a `OrderedDict`. - loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --4901.418938533205 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) --4901.418938533205 -``` -""" -Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = - loglikelihood(model, SimpleVarInfo(θ)) - -# Allow usage of `NamedBijector` too. -function link!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = inverse(t.bijector) - x = vi.values - y, logjac = with_logabsdet_jacobian(b, x) - vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, logjac) - end - return set_transformed!!(vi_new, t) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = t.bijector - y = vi.values - x, inv_logjac = with_logabsdet_jacobian(b, y) - vi_new = Accessors.@set(vi.values = x) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, inv_logjac) - end - return set_transformed!!(vi_new, NoTransformation()) -end - -# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. -from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity -# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`? 
-from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) - return invlink_transform(dist) -end - -has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..8ee850877 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,7 +5,13 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link + DynamicPPL, + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + getlogjoint_internal, + link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -298,8 +304,10 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark - primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) - grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) + logdensity(ldf, params) # Warm-up + primal_benchmark = @be logdensity($ldf, $params) + logdensity_and_gradient(ldf, params) # Warm-up + grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. """ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. 
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index cb949464e..bd6caa93b 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -89,10 +89,10 @@ function logprior_true_with_logabsdet_jacobian end Return a collection of `VarName` as they are expected to appear in the model. Even though it is recommended to implement this by hand for a particular `Model`, -a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +a default implementation using [`VarInfo`](@ref) is provided. """ function varnames(model::Model) - return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) + return collect(keys(VarInfo(model))) end """ diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..0f74da3ae 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -32,24 +32,12 @@ function setup_varinfos( vi_typed_metadata = DynamicPPL.typed_varinfo(model) vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) - # SimpleVarInfo - svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) - svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - - varinfos = map(( - vi_untyped_metadata, - vi_untyped_vnv, - vi_typed_metadata, - vi_typed_vnv, - svi_typed, - svi_untyped, - svi_vnv, - )) do vi - # Set them all to the same values and evaluate logp. - vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi)) - end + varinfos = + map((vi_untyped_metadata, vi_untyped_vnv, vi_typed_metadata, vi_typed_vnv)) do vi + # Set them all to the same values and evaluate logp. + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi)) + end if include_threadsafe varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..2d7b0404f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,9 +5,6 @@ const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. VarNameTuple = NTuple{N,VarName} where {N} -# TODO(mhauru) This is currently used in the transformation functions of NoDist, -# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in -# SimpleVarInfo and maybe other places. """ The type for all log probability variables. @@ -15,6 +12,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. 
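+
+For instance, a hypothetical REPL session with the `d` above (assuming `Bijectors` is
+loaded, and a 64-bit system so that `LogProbType` is `Float64`):
+
+```julia
+julia> Bijectors.with_logabsdet_jacobian(identity, rand(d))
+ERROR: MethodError: no method matching zero(::Type{Any})
+
+julia> Bijectors.with_logabsdet_jacobian(DynamicPPL.typed_identity, rand(d))[2]
+0.0
+```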
+ +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. +""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 5590ac169..0c014a193 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..43b877d62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,30 @@ using Test end end +@testset "ParamsWithStats from FastLDF" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + 
unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + params = [x for x in vi[:]] + + # Get the ParamsWithStats using FastLDF + fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) + ps = ParamsWithStats(params, fldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..0da1f13fb 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -604,13 +604,13 @@ module Issue537 end # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) - @test svi == SimpleVarInfo() + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) + @test vi == VarInfo() if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi + @test retval isa DynamicPPL.ThreadSafeVarInfo{<:VarInfo} + @test retval.varinfo == vi else - @test retval == svi + @test retval == vi end # We should not be altering return-values other than at top-level. @@ -620,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..71f2f13b6 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. 
- function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) @@ -424,8 +422,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() "typed+VNV", DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), ), - ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), ] @model function test_init_model() diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..a75441c93 --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,225 @@ +module DynamicPPLFastEvalTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake + +@testset "LogDensityFunction: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end + end +end + +@testset "Fast evaluation: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models with OnlyAccsVarInfo should not lead to any + # allocations (but only when not using TSVI). 
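+        # (Here TSVI refers to ThreadSafeVarInfo: the `Threads.nthreads() == 1`
+        # guard above exists because the thread-safe wrapper is expected to
+        # allocate, which would make these zero-allocation checks fail.)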
+ @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + @testset "LogDensityFunction" begin + # Performance tests on LogDensityFunction. + vi = VarInfo(model) + fldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end + + # And for returned/logp evaluation functions. + @testset "$func" for func in (returned, logprior, loglikelihood, logjoint) + if model.f !== submodel_outer + # submodel_outer contains nested parameters, so the NamedTuple + # representation doesn't work. One day, we'll fix rand(NamedTuple, + # model) to 'work' with nested parameters. But this will require us to + # figure out submodels properly... + params_nt = rand(NamedTuple, model) + bench = median(@be func(model, params_nt)) + @test iszero(bench.allocs) + end + + # Thank goodness Dicts work... + params_dict = rand(Dict, model) + bench = median(@be func(model, params_dict)) + @test iszero(bench.allocs) + end + end + end +end + +@testset "AD with LogDensityFunction" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. 
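+    # The four models below cover: a scalar eltype used to allocate a `Matrix` or
+    # `Array` inside the model, and a concrete `Matrix`/`Array` type passed
+    # directly as the model's type parameter.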
+ @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..b017c658d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -1,13 +1,16 @@ using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad +using DynamicPPL: OrderedDict using ADTypes: AutoEnzyme using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), +ADTYPES = OrderedDict( + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl deleted file mode 100644 index fbd868f71..000000000 --- a/test/logdensityfunction.jl +++ /dev/null @@ -1,49 +0,0 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = 
DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model - end -end - -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) - end - end - - @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ldf = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.capabilities(typeof(ldf)) == - LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == - LogDensityProblems.LogDensityOrder{1}() - end -end diff --git a/test/model.jl b/test/model.jl index 6da5ea246..2830a131e 100644 --- a/test/model.jl +++ b/test/model.jl @@ -27,7 +27,6 @@ end is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true -is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -314,7 +313,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) + vi = last(DynamicPPL.init!!(model, VarInfo())) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. 
     @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
@@ -492,12 +491,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         end
 
         model = product_dirichlet()
-        varinfos = [
-            DynamicPPL.untyped_varinfo(model),
-            DynamicPPL.typed_varinfo(model),
-            DynamicPPL.typed_simple_varinfo(model),
-            DynamicPPL.untyped_simple_varinfo(model),
-        ]
+        varinfos = [DynamicPPL.untyped_varinfo(model), DynamicPPL.typed_varinfo(model)]
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
             logjoint = getlogjoint(varinfo)  # unlinked space
             varinfo_linked = DynamicPPL.link(varinfo, model)
diff --git a/test/runtests.jl b/test/runtests.jl
index 5e40635e6..47cff58c2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,6 +13,6 @@ using ForwardDiff
 using LogDensityProblems
 using MacroTools
 using MCMCChains
+using Mooncake
 using StableRNGs
 using ReverseDiff
-using Mooncake
@@ -54,12 +54,13 @@ include("test_util.jl")
         include("compiler.jl")
         include("varnamedvector.jl")
         include("varinfo.jl")
-        include("simple_varinfo.jl")
         include("model.jl")
         include("distribution_wrappers.jl")
-        include("logdensityfunction.jl")
         include("linking.jl")
         include("serialization.jl")
+    end
+
+    if GROUP == "All" || GROUP == "Group2"
         include("pointwise_logdensities.jl")
         include("lkj.jl")
         include("contexts.jl")
@@ -69,9 +70,7 @@ include("test_util.jl")
         include("submodels.jl")
         include("chains.jl")
         include("bijector.jl")
-    end
-
-    if GROUP == "All" || GROUP == "Group2"
+        include("fasteval.jl")
         @testset "extensions" begin
             include("ext/DynamicPPLMCMCChainsExt.jl")
             include("ext/DynamicPPLJETExt.jl")
@@ -80,7 +79,6 @@ include("test_util.jl")
         @testset "ad" begin
             include("ext/DynamicPPLForwardDiffExt.jl")
             include("ext/DynamicPPLMooncakeExt.jl")
-            include("ad.jl")
         end
         @testset "prob and logprob macro" begin
             @test_throws ErrorException prob"..."
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl deleted file mode 100644 index 488cb8941..000000000 --- a/test/simple_varinfo.jl +++ /dev/null @@ -1,320 +0,0 @@ -@testset "simple_varinfo.jl" begin - @testset "constructor & indexing" begin - @testset "NamedTuple" begin - svi = SimpleVarInfo(; m=1.0) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(; m=[1.0]) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(; m=(a=[1.0],)) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogjoint(svi) isa Float32 - - svi = SimpleVarInfo((m=1.0,)) - svi = accloglikelihood!!(svi, 1.0) - @test getlogjoint(svi) == 1.0 - end - - @testset "Dict" begin - svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) - # Now we only have a variable `m.a` which is subsumed by `m`, - # but we can't guarantee that we have the "entire" `m`. - @test !haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - end - - @testset "VarNamedVector" begin - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0])) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the - # next test is here to remind of us that. - svi = SimpleVarInfo( - push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0]) - ) - @test_broken !haskey(svi, @varname(m.a.b.c.d)) - end - end - - @testset "link!! & invlink!! 
on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), - ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), - ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), - ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), - ) - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) - end - vi = last(DynamicPPL.evaluate!!(model, vi)) - - # Calculate ground truth - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_constrained... - ) - - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_unlinked = getlogjoint(vi_linked) - lp_linked = getlogjoint_internal(vi_linked) - @test lp_linked ≈ lp_linked_true - @test lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_linked) ≈ lp_unlinked - - # `invlink!!` - vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_unlinked = getlogjoint(vi_invlinked) - also_lp_unlinked = getlogjoint_internal(vi_invlinked) - @test lp_unlinked ≈ lp_unlinked_true - @test also_lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_invlinked) ≈ lp_unlinked - - # Should result in same values. - @test all( - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ) - end - end - - @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - # We might need to pre-allocate for the variable `m`, so we need - # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) - svi_dict = SimpleVarInfo(VarInfo(model), Dict) - vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) - vnv = push!!(vnv, VarName{k}() => v) - end - svi_vnv = SimpleVarInfo(vnv) - - @testset "$name" for (name, svi) in ( - ("NamedTuple", svi_nt), - ("Dict", svi_dict), - ("VarNamedVector", svi_vnv), - # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), - ) - # Random seed is set in each `@testset`, so we need to sample - # a new realization for `m` here. - retval = model() - - ### Sampling ### - # Sample a new varinfo! - _, svi_new = DynamicPPL.init!!(model, svi) - - # Realization for `m` should be different wp. 1. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_new[vn] != get(retval, vn) - end - - # Logjoint should be non-zero wp. 1. - @test getlogjoint(svi_new) != 0 - - ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.is_transformed(svi) - _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - # Make sure that these two computation paths provide the same - # transformed values. - @test values_eval == _values_prior - else - logpri_true = DynamicPPL.TestUtils.logprior_true( - model, values_eval_constrained... 
- ) - logπ_true = DynamicPPL.TestUtils.logjoint_true( - model, values_eval_constrained... - ) - values_eval = values_eval_constrained - end - - # No logabsdet-jacobian correction needed for the likelihood. - loglik_true = DynamicPPL.TestUtils.loglikelihood_true( - model, values_eval_constrained... - ) - - # Update the realizations in `svi_new`. - svi_eval = svi_new - for vn in DynamicPPL.TestUtils.varnames(model) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) - end - - # Reset the logp accumulators. - svi_eval = DynamicPPL.resetaccs!!(svi_eval) - - # Compute `logjoint` using the varinfo. - logπ = logjoint(model, svi_eval) - logpri = logprior(model, svi_eval) - loglik = loglikelihood(model, svi_eval) - - # Values should not have changed. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_eval[vn] == get(values_eval, vn) - end - - # Compare log-probability computations. - @test logpri ≈ logpri_true - @test loglik ≈ loglik_true - @test logπ ≈ logπ_true - end - end - - @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - - # Initialize. - svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.init!!(model, svi_nt)) - svi_vnv = DynamicPPL.set_transformed!!( - SimpleVarInfo(DynamicPPL.VarNamedVector()), true - ) - svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) - - for svi in (svi_nt, svi_vnv) - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` - - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) - - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 - end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint_internal(svi) - # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 - @test lp ≈ lp_true atol = 1.2e-5 - end - end - end - - @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() - - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Initialize varinfo and link. - vi_linked = DynamicPPL.link!!(vi, model) - - # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.is_transformed( - DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) - ) - - # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) - @test !DynamicPPL.is_transformed(vi_result) - - # Set the values to something that is out of domain if we're in constrained space. - for vn in keys(vi) - vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) - end - - # NOTE: Evaluating a linked VarInfo, **specifically when the transformation - # is static**, will result in an invlinked VarInfo. This is because of - # `maybe_invlink_before_eval!`, which only invlinks if the transformation - # is static. 
(src/abstract_varinfo.jl) - retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ - DynamicPPL.tovec(retval.s) # `s` is unconstrained in original - @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) - ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result - - # `m` should not be transformed. - @test vi_linked[@varname(m)] == retval.m - @test vi_unlinked_again[@varname(m)] == retval.m - - # Get ground truths - retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.s, retval.m - ) - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ - DynamicPPL.tovec(retval_unconstrained.s) - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ - DynamicPPL.tovec(retval_unconstrained.m) - - # The unlinked varinfo should hold the unlinked logp. - lp_unlinked = getlogjoint(vi_unlinked_again) - @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true - end - end -end diff --git a/test/test_util.jl b/test/test_util.jl index 94fdbd744..911de1079 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -25,20 +25,6 @@ function short_varinfo_name(vi::DynamicPPL.NTVarInfo) end short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" -function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) - return "SimpleVarInfo{<:NamedTuple,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) - return "SimpleVarInfo{<:OrderedDict,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) - return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -end -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" -end # convenient functions for testing model.jl # function to modify the representation of values based on their length diff --git a/test/varinfo.jl b/test/varinfo.jl index a1a1b370f..f9ce7171f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,17 +1,7 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `identity`. - # So we just check that the original keys are present. - for vn in vns - # Should have all the original keys. - @test haskey(varinfo, vn) - end - else - vns_varinfo = keys(varinfo) - # Should be equivalent. - @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) - end + vns_varinfo = keys(varinfo) + # Should be equivalent. 
+ @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) end """ @@ -100,9 +90,6 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) - test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict{VarName,Any}())) - test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "get/set/acclogp" begin @@ -129,9 +116,6 @@ end vi = VarInfo() test_varinfo_logp!(vi) test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) - test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(Dict())) - test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "logp accumulators" begin @@ -444,19 +428,6 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) - - ### `SimpleVarInfo` - ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -514,20 +485,6 @@ end model, value_true, varnames; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - if DynamicPPL.has_varnamedvector(varinfo) && mutating # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. @test_broken false @@ -591,12 +548,6 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the inconcrete `SimpleVarInfo` types, since checking for type - # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} - continue - end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) end end @@ -618,9 +569,6 @@ end model, model(), vns; include_threadsafe=true ) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -648,33 +596,18 @@ end [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. - vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. 
- vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, VarName[]) @test isempty(varinfo_subset) end @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, vns_subset) # Should now only contain the variables in `vns_subset`. check_varinfo_keys(varinfo_subset, vns_subset) @@ -709,7 +642,7 @@ end end @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, vns_subset) vns_subset_reversed = reverse(vns_subset) varinfo_subset_reversed = subset(varinfo, vns_subset_reversed) @@ -718,15 +651,6 @@ end @test varinfo_subset[:] == ground_truth end end - - # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end end @testset "merge" begin