From 8c3d30f78b375e00c1ed4fba75b0d8d0414e43a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Oct 2025 18:08:25 +0100 Subject: [PATCH 01/15] v0.39 --- HISTORY.md | 2 ++ Project.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 57ccaecd1..8165317f6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,7 @@ # DynamicPPL Changelog +## 0.39.0 + ## 0.38.0 ### Breaking changes diff --git a/Project.toml b/Project.toml index 2fe65fd7b..7f58083bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.0" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7300c224dc0fbac2febea2b6185c74f48181dcc9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:48:58 +0000 Subject: [PATCH 02/15] Update DPPL compats for benchmarks and docs --- benchmarks/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 0d4e9a654..55ca81da0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -23,7 +23,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" diff --git a/docs/Project.toml b/docs/Project.toml index fed06ebde..69e0a4c5a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" From 79150baf0a70380f46d9573746a900f7c38bb370 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Nov 2025 17:31:18 +0000 Subject: [PATCH 03/15] remove merge conflict markers --- HISTORY.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cd4cbc767..45be1772d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -19,8 +19,6 @@ The generic method `returned(::Model, values, keys)` is deprecated and will be r Added a compatibility entry for JET@0.11. -> > > > > > > main - ## 0.38.1 Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`. From 4ca95281ebbe451f99813f009bfc9c7140330a45 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 16:24:14 +0000 Subject: [PATCH 04/15] Remove `NodeTrait` (#1133) * Remove NodeTrait * Changelog * Fix exports * docs * fix a bug * Fix doctests * Fix test * tweak changelog --- HISTORY.md | 17 +++++++ docs/src/api.md | 46 +++++++++++++---- src/DynamicPPL.jl | 13 +++-- src/contexts.jl | 82 ++++++++++++------------------ src/contexts/conditionfix.jl | 92 +++++++++++----------------------- src/contexts/default.jl | 1 - src/contexts/init.jl | 1 - src/contexts/prefix.jl | 17 ++----- src/contexts/transformation.jl | 1 - src/model.jl | 4 +- src/test_utils/contexts.jl | 19 ++----- test/contexts.jl | 36 +++++++------ 12 files changed, 154 insertions(+), 175 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 613957c33..f181897f7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,23 @@ ## 0.39.0 +### Breaking changes + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. 
+This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`. + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users. + +Leaf contexts require no changes, apart from a removal of the `NodeTrait` function. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### Miscellaneous + Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. ## 0.38.9 diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..63dafdfca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -352,13 +352,6 @@ Base.empty! SimpleVarInfo ``` -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -463,15 +456,48 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context. + +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks. +They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. 
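For illustration only (not part of the patch), a minimal sketch of the parent-context interface described in the `docs/src/api.md` changes above. The name `TraceContext` is hypothetical; it simply records each assumed variable name and defers all real work to its child context:

```julia
using DynamicPPL, Distributions

# Hypothetical parent context: records every assumed variable name, then
# delegates evaluation to its child context.
struct TraceContext{C<:AbstractContext} <: AbstractParentContext
    names::Vector{VarName}
    context::C
end
TraceContext() = TraceContext(VarName[], DefaultContext())

# The two methods required by the AbstractParentContext interface.
DynamicPPL.childcontext(ctx::TraceContext) = ctx.context
function DynamicPPL.setchildcontext(ctx::TraceContext, child::AbstractContext)
    return TraceContext(ctx.names, child)
end

# Optional: intercept tilde_assume!! before handing off to the child context.
function DynamicPPL.tilde_assume!!(
    ctx::TraceContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
)
    push!(ctx.names, vn)
    return DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(ctx), dist, vn, vi)
end
```

Such a context could then be attached to a model with `contextualize(model, TraceContext())` before evaluation.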
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..c43bd89d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -94,16 +94,21 @@ export AbstractVarInfo, values_as_in_model, # LogDensityFunction LogDensityFunction, - # Contexts + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. """ -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. 
```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! not implemented for context of type $(typeof(context))") +end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. 
""" struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ function getconditioned_nested(::IsParent, context, vn) end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. 
""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants. 
""" function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. 
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..83507353f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -150,7 +150,6 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext vn_prefix::Tvn context::C end @@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} end PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) -NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context function setchildcontext(ctx::PrefixContext, child::AbstractContext) return PrefixContext(ctx.vn_prefix, child) @@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`. 
function prefix(ctx::PrefixContext, vn::VarName) return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) +prefix(::AbstractContext, vn::VarName) = vn +function prefix(ctx::AbstractParentContext, vn::VarName) return prefix(childcontext(ctx), vn) end @@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) ) return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) +prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName) vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) return vn, setchildcontext(ctx, new_ctx) end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c2eee2863 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -10,7 +10,6 @@ Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the how to do the transformation, used by e.g. `SimpleVarInfo`. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/model.jl b/src/model.jl index ec98b90cd..94fcd9fd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. 
""" function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..ae7332a43 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. 
- function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) From 535ce4f68e8f162fb382fb5d55eae0238d332e7a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 13 Nov 2025 13:30:43 +0000 Subject: [PATCH 05/15] FastLDF / InitContext unified (#1132) * Fast Log Density Function * Make it work with AD * Optimise performance for identity VarNames * Mark `get_range_and_linked` as having zero derivative * Update comment * make AD testing / benchmarking use FastLDF * Fix tests * Optimise away `make_evaluate_args_and_kwargs` * const func annotation * Disable benchmarks on non-typed-Metadata-VarInfo * Fix `_evaluate!!` correctly to handle submodels * Actually fix submodel evaluate * Document thoroughly and organise code * Support more VarInfos, make it thread-safe (?) * fix bug in parsing ranges from metadata/VNV * Fix get_param_eltype for TSVI * Disable Enzyme benchmark * Don't override _evaluate!!, that breaks ForwardDiff (sometimes) * Move FastLDF to experimental for now * Fix imports, add tests, etc * More test fixes * Fix imports / tests * Remove AbstractFastEvalContext * Changelog and patch bump * Add correctness tests, fix imports * Concretise parameter vector in tests * Add zero-allocation tests * Add Chairmarks as test dep * Disable allocations tests on multi-threaded * Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation * Refactor FastLDF to use InitContext * note init breaking change * fix logjac sign * workaround Mooncake segfault * fix changelog too * Fix get_param_eltype for context stacks * Add a test for threaded observe * Export init * Remove dead code * fix transforms for pathological distributions * Tidy up loads of things * fix typed_identity spelling * fix definition order * Improve docstrings * Remove stray comment * export get_param_eltype (unfortunatley) * Add more comment * Update comment * Remove inlines, fix OAVI docstring * Improve docstrings * Simplify InitFromParams constructor * Replace map(identity, x[:]) with [i for i in x[:]] * Simplify implementation for InitContext/OAVI * Add another model to allocation tests Co-authored-by: Markus Hauru * Revert removal of dist argument (oops) * Format * Update some outdated bits of FastLDF docstring * remove underscores --------- Co-authored-by: Markus Hauru --- HISTORY.md | 15 ++ docs/src/api.md | 12 +- ext/DynamicPPLEnzymeCoreExt.jl | 13 +- ext/DynamicPPLMooncakeExt.jl | 3 + src/DynamicPPL.jl | 5 +- src/compiler.jl | 38 ++-- src/contexts/init.jl | 255 ++++++++++++++++++++---- src/experimental.jl | 2 + src/fasteval.jl | 336 ++++++++++++++++++++++++++++++++ src/model.jl | 32 ++- src/onlyaccs.jl | 42 ++++ src/utils.jl | 35 ++++ test/Project.toml | 1 + test/fasteval.jl | 233 ++++++++++++++++++++++ test/integration/enzyme/main.jl | 6 +- test/runtests.jl | 1 + 16 files changed, 955 insertions(+), 74 deletions(-) create mode 100644 src/fasteval.jl create mode 100644 src/onlyaccs.jl create mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 
f181897f7..0f0102ce4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,21 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). + +### Other changes + +#### FastLDF + +Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +Please note that `FastLDF` is currently considered internal and its API may change without warning. +We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. + +For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index 63dafdfca..e81f18dc7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -170,6 +170,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -517,10 +523,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. +In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL._get_range_and_linked), args... 
+) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -5,5 +5,8 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL._get_range_and_linked),Vararg +} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c43bd89d5..e9b902363 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -84,8 +84,8 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, @@ -113,6 +113,8 @@ export AbstractVarInfo, InitFromPrior, InitFromUniform, InitFromParams, + init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, @@ -193,6 +195,7 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 83507353f..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -14,14 +14,60 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. 
The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion +""" +get_param_eltype(::AbstractInitStrategy) = Any + """ InitFromPrior() @@ -29,7 +75,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return rand(rng, dist), typed_identity end """ @@ -69,43 +115,61 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return x, typed_identity end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple. +Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +## Custom parameter storage types -!!! 
note - The values in `params` must be provided in the space of the untransformed - distribution. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement + +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return InitFromParams(to_varname_dict(params), fallback) - end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) +InitFromParams(params) = InitFromParams(params, InitFromPrior()) + +function init( + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -119,13 +183,89 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. 
+ +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = if range_and_linked.is_linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return (@view vr.vect[range_and_linked.range]), transform +end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( @@ -155,9 +295,8 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we @@ -165,17 +304,49 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). 
It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. + y, -inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/experimental.jl b/src/experimental.jl index 8c82dca68..c644c09b2 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,6 +2,8 @@ module Experimental using DynamicPPL: DynamicPPL +include("fasteval.jl") + # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) 
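For illustration only (not part of the patch), a minimal sketch of a custom `AbstractInitStrategy` following the new `DynamicPPL.init` interface from `src/contexts/init.jl` above, which now returns the generated value together with a transform back to unlinked space. The name `InitFromMean` is hypothetical:

```julia
using DynamicPPL, Distributions, Random

# Hypothetical strategy: initialise every variable at the mean of its prior.
struct InitFromMean <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(
    ::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromMean
)
    # `mean(dist)` is already in unlinked (model) space, so no transform is
    # needed beyond `typed_identity`.
    return mean(dist), DynamicPPL.typed_identity
end
```

Because this strategy never carries tracer-typed parameters, the default `get_param_eltype` fallback of `Any` suffices, and the strategy can be passed anywhere an `AbstractInitStrategy` is accepted, e.g. to `init!!`.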
diff --git a/src/fasteval.jl b/src/fasteval.jl new file mode 100644 index 000000000..c668b1413 --- /dev/null +++ b/src/fasteval.jl @@ -0,0 +1,336 @@ +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI +using Random: Random + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. 
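+
+## Example
+
+A minimal, illustrative usage sketch (not a doctest; the model and values are made up
+for demonstration):
+
+```julia
+using DynamicPPL, Distributions
+using ADTypes: AutoForwardDiff
+using LogDensityProblems: logdensity, logdensity_and_gradient
+import ForwardDiff
+
+@model function demo()
+    x ~ Normal()
+    return 2.0 ~ Normal(x)
+end
+
+# Construct with an AD type so that gradients are also available.
+fldf = DynamicPPL.Experimental.FastLDF(demo(); adtype=AutoForwardDiff())
+logdensity(fldf, [0.5])               # log joint at x = 0.5
+logdensity_and_gradient(fldf, [0.5])  # log joint and its gradient w.r.t. x
+```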
+ +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext`. + +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we +store a mapping from VarNames to ranges in that vector, along with link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. 
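+
+As a purely illustrative sketch of the stored mapping (the exact layout is an internal
+detail), a model with `x ~ Normal()` and `y ~ MvNormal(zeros(2), I)` might be
+represented as
+
+```julia
+# Identity-optic VarNames are kept in a NamedTuple for fast lookup ...
+(x = RangeAndLinked(1:1, false), y = RangeAndLinked(2:3, false))
+# ... while all other VarNames (e.g. @varname(z[1])) live in a Dict{VarName,RangeAndLinked}.
+```
+
+so that, given a parameter vector `params`, evaluation reads `params[1]` for `x` and
+`params[2:3]` for `y`.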
+""" +struct FastLDF{ + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} + model::M + adtype::AD + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + + function FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) + end + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + ) + end +end + +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) +end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) +end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) + +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = InitContext( + Random.default_rng(), + InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + ), + ) + model = DynamicPPL.setleafcontext(f.model, ctx) + accs = fast_ldf_accs(f.getlogdensity) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + accs = map( + acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), + accs, + ) + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + _, vi = DynamicPPL._evaluate!!(model, vi) + return f.getlogdensity(vi) +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + return FastLogDensityAt( + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + )( + params + ) +end + +function LogDensityProblems.logdensity_and_gradient( + fldf::FastLDF, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + FastLogDensityAt( + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + ), + fldf._adprep, + fldf.adtype, + params, + ) +end + +###################################################### +# Helper functions to extract ranges and link status # +###################################################### + +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
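+
+For instance (illustrative only; actual ranges depend on the VarInfo's contents and link
+status), a linked `VarInfo` holding `s ~ Exponential()` and `m[1] ~ Normal()` might give
+
+```julia
+((s = RangeAndLinked(1:1, true),), Dict(@varname(m[1]) => RangeAndLinked(2:2, true)))
+```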
+""" +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end diff --git a/src/model.jl b/src/model.jl index 94fcd9fd4..2bcfe8f98 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... + ) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1010,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) + +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. 
+""" +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) + return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) +end +get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) +function get_param_eltype(::AbstractVarInfo, ctx::InitContext) + return get_param_eltype(ctx.strategy) +end + """ getargnames(model::Model) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl new file mode 100644 index 000000000..940f23124 --- /dev/null +++ b/src/onlyaccs.jl @@ -0,0 +1,42 @@ +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for +`InitContext`. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this with a different leaf context such as `DefaultContext` will result in errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +This is also why it only works with `InitContext`: in this case, the parameters used for +evaluation are supplied by the context instead of the metadata. +""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# Minimal AbstractVarInfo interface +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +# Ideally, we'd define this together with InitContext, but alas that file comes way before +# this one, and sorting out the include order is a pain. +function tilde_assume!!( + ctx::InitContext, + dist::Distribution, + vn::VarName, + vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, +) + # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can + # cut out a lot of the code above. + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi +end diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..75fb805dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. 
+ +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. +""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..db2333711 --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,233 @@ +module DynamicPPLFastLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end + +@testset "FastLDF: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. 
+ @testset "$getlogdensity" for getlogdensity in ( + DynamicPPL.getlogjoint_internal, + DynamicPPL.getlogjoint, + DynamicPPL.getloglikelihood, + DynamicPPL.getlogprior_internal, + DynamicPPL.getlogprior, + ) + ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) + fldf = FastLDF(m, getlogdensity, vi) + @test LogDensityProblems.logdensity(ldf, params) ≈ + LogDensityProblems.logdensity(fldf, params) + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.Experimental.FastLDF(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end + end +end + +@testset "FastLDF: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with FastLDF" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = @static if VERSION < v"1.12" + [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + else + [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] + end + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. 
+ @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = FastLDF(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..10fac8b0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,7 @@ include("test_util.jl") include("ext/DynamicPPLMooncakeExt.jl") end include("ad.jl") + include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." 
From 9624103885a6f8b6cac70b1d0796da5a6227b65a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:34:59 +0000 Subject: [PATCH 06/15] implement `LogDensityProblems.dimension` --- src/fasteval.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index c668b1413..aa2fdd933 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -149,6 +149,7 @@ struct FastLDF{ _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adprep::ADP + _dim::Int function FastLDF( model::Model, @@ -159,13 +160,14 @@ struct FastLDF{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - x = [val for val in varinfo[:]] DI.prepare_gradient( FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), adtype, @@ -179,7 +181,7 @@ struct FastLDF{ typeof(all_iden_ranges), typeof(prep), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end @@ -260,6 +262,10 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.dimension(fldf::FastLDF) + return fldf._dim +end + ###################################################### # Helper functions to extract ranges and link status # ###################################################### From ce807139b31919b40afc98bcc522e2f41e14dc30 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:40:37 +0000 Subject: [PATCH 07/15] forgot about capabilities... 
--- src/fasteval.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fasteval.jl b/src/fasteval.jl index aa2fdd933..4f402f4a8 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -262,6 +262,16 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} +) where {M} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} +) where {M} + return LogDensityProblems.LogDensityOrder{1}() +end function LogDensityProblems.dimension(fldf::FastLDF) return fldf._dim end From 8553e401182b894a653b638e5978f00ae42ae031 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 12:03:14 +0000 Subject: [PATCH 08/15] use interpolation in run_ad --- src/test_utils/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..d7a34e6e0 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -298,8 +298,8 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark - primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) - grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) + primal_benchmark = @be logdensity($ldf, $params) + grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) From 3cd8d3431e14ebc581266c1323d1db8a5bd4c0eb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 17:48:32 +0000 Subject: [PATCH 09/15] Improvements to benchmark outputs (#1146) * print output * fix * reenable * add more lines to guide the eye * reorder table * print tgrad / trel as well * forgot this type --- benchmarks/benchmarks.jl | 38 +++++++++++++++++++++++++++++++++----- src/test_utils/ad.jl | 10 +++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 3af6573cf..e8ffa7e0b 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -98,12 +98,15 @@ function run(; to_json=false) }[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name" + @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try results = benchmark(model, varinfo_choice, adbackend, islinked) + @info " t(eval) = $(results.primal_time)" + @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), (results.grad_time / results.primal_time) catch e + @info "benchmark errored: $e" missing, missing end push!( @@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) ) results_table = Tuple{ - String,Int,String,String,Bool,String,String,String,String,String,String + String, + Int, + String, + String, + Bool, + String, + String, + String, + String, + String, + String, + String, + String, + String, }[] + sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ EmptyCells(5), 
MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), + MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"], + [colnames[1:5]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String) # Finally that lets us do this division safely speedup_eval = base_eval / head_eval speedup_grad = base_grad / head_grad + # As well as this multiplication, which is t(grad) / t(ref) + head_grad_vs_ref = head_grad * head_eval + base_grad_vs_ref = base_grad * base_eval + speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref push!( results_table, ( @@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String) sprint_float(base_grad), sprint_float(head_grad), sprint_float(speedup_grad), + sprint_float(base_grad_vs_ref), + sprint_float(head_grad_vs_ref), + sprint_float(speedup_grad_vs_ref), ), ) end @@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String) backend=:text, fit_table_in_display_horizontally=false, fit_table_in_display_vertically=false, - table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true), + table_format=TextTableFormat(; + horizontal_line_at_merged_column_labels=true, + horizontal_lines_at_data_rows=collect(3:3:length(results_table)), + ), ) println("```") end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d7a34e6e0..8ee850877 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,7 +5,13 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link + DynamicPPL, + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + getlogjoint_internal, + link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -298,7 +304,9 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark + logdensity(ldf, params) # Warm-up primal_benchmark = @be logdensity($ldf, $params) + logdensity_and_gradient(ldf, params) # Warm-up grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time From 4a1156038eb673dc9567d8a3a4d008455ec83908 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 22 Nov 2025 00:26:16 +0000 Subject: [PATCH 10/15] Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129) * Implement `ParamsWithStats` for `FastLDF` * Add comments * Implement `bundle_samples` for ParamsWithStats -> MCMCChains * Remove redundant comment * don't need Statistics? 
--- ext/DynamicPPLMCMCChainsExt.jl | 37 ++++++++++++++++ src/DynamicPPL.jl | 2 +- src/chains.jl | 57 ++++++++++++++++++++++++ src/fasteval.jl | 81 ++++++++++++++++++++++++---------- test/chains.jl | 28 +++++++++++- 5 files changed, 180 insertions(+), 25 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..e74f0b8a9 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -140,6 +140,43 @@ function AbstractMCMC.to_samples( end end +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DynamicPPL.Model, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + save_state=false, + stats=missing, + sort_chain=false, + discard_initial=0, + thinning=1, + kwargs..., +) + bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1)) + + # Add additional MCMC-specific info + info = bare_chain.info + if save_state + info = merge(info, (model=model, sampler=spl, samplerstate=state)) + end + if !ismissing(stats) + info = merge(info, (start_time=stats.start, stop_time=stats.stop)) + end + + # Reconstruct the chain with the extra information + # Yeah, this is quite ugly. Blame MCMCChains. + chain = MCMCChains.Chains( + bare_chain.value.data, + names(bare_chain), + bare_chain.name_map; + info=info, + start=discard_initial + 1, + thin=thinning, + ) + return sort_chain ? sort(chain) : chain +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..6d3900e91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -202,6 +202,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -209,7 +210,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..892423822 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -133,3 +133,60 @@ function ParamsWithStats( end return ParamsWithStats(params, stats) end + +""" + ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided +`param_vector`. + +This method is intended to replace the old method of obtaining parameters and statistics +via `unflatten` plus re-evaluation. It is faster for two reasons: + +1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as + otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent + MCMC iterations). +2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`. 
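+
+A rough usage sketch (assuming `model` is a `Model` and `vi = VarInfo(model)`):
+
+```julia
+fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint, vi)
+ps = ParamsWithStats(vi[:], fldf)
+ps.params   # values of the model's random variables, keyed by VarName
+ps.stats    # includes :logprior, :loglikelihood and :lp by default
+```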
+""" +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + strategy = InitFromParams( + VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + nothing, + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + _, vi = DynamicPPL.Experimental.fast_evaluate!!( + ldf.model, strategy, AccumulatorTuple(accs) + ) + params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + lp=DynamicPPL.getlogjoint(vi), + ), + ) + end + return ParamsWithStats(params, stats) +end diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..722760fa1 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -3,6 +3,7 @@ using DynamicPPL: AccumulatorTuple, InitContext, InitFromParams, + AbstractInitStrategy, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI using Random: Random +""" + DynamicPPL.Experimental.fast_evaluate!!( + [rng::Random.AbstractRNG,] + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, params::AbstractVector{<:Real} + ) + +Evaluate a model using parameters obtained via `strategy`, and only computing the results in +the provided accumulators. + +It is assumed that the accumulators passed in have been initialised to appropriate values, +as this function will not reset them. The default constructors for each accumulator will do +this for you correctly. + +Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` +argument may be mutated (depending on how the accumulators are implemented); hence the `!!` +in the function name. +""" +@inline function fast_evaluate!!( + # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads + # to extra allocations (even for trivial models) and much slower runtime. + rng::Random.AbstractRNG, + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, +) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + return DynamicPPL._evaluate!!(model, vi) +end +@inline function fast_evaluate!!( + model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple +) + # This `@inline` is also mandatory for performance + return fast_evaluate!!(Random.default_rng(), model, strategy, accs) +end + """ FastLDF( model::Model, @@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = InitContext( - Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - accs = map( - acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), - accs, - ) - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - _, vi = DynamicPPL._evaluate!!(model, vi) + _, vi = fast_evaluate!!(f.model, strategy, accs) return f.getlogdensity(vi) end diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..43b877d62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,30 @@ using Test end end +@testset "ParamsWithStats from FastLDF" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + params = [x for x in vi[:]] + + # Get the ParamsWithStats using FastLDF + fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) + ps = ParamsWithStats(params, fldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module From 766f6635903c401a79d3c2427dc60225f0053dad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 11:41:51 +0000 Subject: [PATCH 11/15] Make FastLDF the default (#1139) * Make FastLDF the default * Add miscellaneous LogDensityProblems tests * Use `init!!` instead of `fast_evaluate!!` * Rename files, rebalance tests --- HISTORY.md | 23 +- docs/src/api.md | 8 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 11 +- src/DynamicPPL.jl | 4 + src/chains.jl | 8 +- src/experimental.jl | 2 - 
src/fasteval.jl | 387 --------------- src/logdensityfunction.jl | 579 +++++++++++------------ src/model.jl | 54 ++- test/ad.jl | 137 ------ test/chains.jl | 8 +- test/fasteval.jl | 233 --------- test/logdensityfunction.jl | 263 ++++++++-- test/runtests.jl | 7 +- 14 files changed, 584 insertions(+), 1140 deletions(-) delete mode 100644 src/fasteval.jl delete mode 100644 test/ad.jl delete mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 0f0102ce4..91306c219 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,17 @@ ### Breaking changes +#### Fast Log Density Functions + +This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + +As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. +In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. +If you were previously relying on this behaviour, you will need to store a VarInfo separately. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. @@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). -### Other changes - -#### FastLDF - -Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. -Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. - -Please note that `FastLDF` is currently considered internal and its API may change without warning. -We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. - -For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. - ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index e81f18dc7..adb476db5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished using [`init!!`](@ref) on: + +```@docs +OnlyAccsVarInfo +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -510,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo. It is really a thin wrapper around using `evaluate!!` with an `InitContext`. 
```@docs -DynamicPPL.init!! +init!! ``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... ) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6d3900e91..a885f6a96 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,8 +92,12 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, + # evaluation + evaluate!!, + init!!, # LogDensityFunction LogDensityFunction, + OnlyAccsVarInfo, # Leaf contexts AbstractContext, contextualize, diff --git a/src/chains.jl b/src/chains.jl index 892423822..f176b8e68 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -137,7 +137,7 @@ end """ ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. 
It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -174,9 +174,7 @@ function ParamsWithStats( else (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) end - _, vi = DynamicPPL.Experimental.fast_evaluate!!( - ldf.model, strategy, AccumulatorTuple(accs) - ) + _, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy) params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values if include_log_probs stats = merge( diff --git a/src/experimental.jl b/src/experimental.jl index c644c09b2..8c82dca68 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,8 +2,6 @@ module Experimental using DynamicPPL: DynamicPPL -include("fasteval.jl") - # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) diff --git a/src/fasteval.jl b/src/fasteval.jl deleted file mode 100644 index 722760fa1..000000000 --- a/src/fasteval.jl +++ /dev/null @@ -1,387 +0,0 @@ -using DynamicPPL: - AbstractVarInfo, - AccumulatorTuple, - InitContext, - InitFromParams, - AbstractInitStrategy, - LogJacobianAccumulator, - LogLikelihoodAccumulator, - LogPriorAccumulator, - Model, - ThreadSafeVarInfo, - VarInfo, - OnlyAccsVarInfo, - RangeAndLinked, - VectorWithRanges, - Metadata, - VarNamedVector, - default_accumulators, - float_type_with_fallback, - getlogjoint, - getlogjoint_internal, - getloglikelihood, - getlogprior, - getlogprior_internal -using ADTypes: ADTypes -using BangBang: BangBang -using AbstractPPL: AbstractPPL, VarName -using LogDensityProblems: LogDensityProblems -import DifferentiationInterface as DI -using Random: Random - -""" - DynamicPPL.Experimental.fast_evaluate!!( - [rng::Random.AbstractRNG,] - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, params::AbstractVector{<:Real} - ) - -Evaluate a model using parameters obtained via `strategy`, and only computing the results in -the provided accumulators. - -It is assumed that the accumulators passed in have been initialised to appropriate values, -as this function will not reset them. The default constructors for each accumulator will do -this for you correctly. - -Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` -argument may be mutated (depending on how the accumulators are implemented); hence the `!!` -in the function name. -""" -@inline function fast_evaluate!!( - # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads - # to extra allocations (even for trivial models) and much slower runtime. - rng::Random.AbstractRNG, - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, -) - ctx = InitContext(rng, strategy) - model = DynamicPPL.setleafcontext(model, ctx) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - param_eltype = DynamicPPL.get_param_eltype(strategy) - accs = map(accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) - end - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - return DynamicPPL._evaluate!!(model, vi) -end -@inline function fast_evaluate!!( - model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple -) - # This `@inline` is also mandatory for performance - return fast_evaluate!!(Random.default_rng(), model, strategy, accs) -end - -""" - FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at that point. - -This information can be extracted using the LogDensityProblems.jl interface, specifically, -using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If -`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD -backend type, then `logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of - linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of - linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, - since transforms are only applied to random variables) - -!!! note - By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of - `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created - with a linked or unlinked VarInfo. This is done primarily to ease interoperability with - MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created for you. If -you provide a different function, you have to manually create a VarInfo and pass it as the -third argument. - -If the `adtype` keyword argument is provided, then this struct will also store the adtype -along with other information for efficient calculation of the gradient of the log density. -Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend -itself to have been loaded (e.g. with `import Backend`). - -## Fields - -Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: - -- `fastldf.model`: The original model from which this `FastLDF` was constructed. -- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD - type was provided. - -# Extended help - -Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a -given set of parameters: - -1. 
With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. - -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. - -In general, both of these approaches work fine, but the fact that they modify the VarInfo's -metadata can often be quite wasteful. In particular, it is very common that the only outputs -we care about from model evaluation are those which are stored in accumulators, such as log -probability densities, or `ValuesAsInModel`. - -To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains -accumulators. It implements enough of the `AbstractVarInfo` interface to not error during -model evaluation. - -Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with -it, it is mandatory that parameters are provided from outside the VarInfo, namely via -`InitContext`. - -The main problem that we face is that it is not possible to directly implement -`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. -In particular, it is not clear: - - - which parts of the vector correspond to which random variables, and - - whether the variables are linked or unlinked. - -Traditionally, this problem has been solved by `unflatten`, because that function would -place values into the VarInfo's metadata alongside the information about ranges and linking. -That way, when we evaluate with `DefaultContext`, we can read this information out again. -However, we want to avoid using a metadata. Thus, here, we _extract this information from -the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we -store a mapping from VarNames to ranges in that vector, along with link status. - -For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all -other VarNames, this is stored in a Dict. The internal data structure used to represent this -could almost certainly be optimised further. See e.g. the discussion in -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. - -When evaluating the model, this allows us to combine the parameter vector together with those -ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read -parameter values from the vector. - -Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. -""" -struct FastLDF{ - M<:Model, - AD<:Union{ADTypes.AbstractADType,Nothing}, - F<:Function, - N<:NamedTuple, - ADP<:Union{Nothing,DI.GradientPrep}, -} - model::M - adtype::AD - _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} - _adprep::ADP - _dim::Int - - function FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - # Figure out which variable corresponds to which index, and - # which variables are linked. 
- all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) - x = [val for val in varinfo[:]] - dim = length(x) - # Do AD prep if needed - prep = if adtype === nothing - nothing - else - # Make backend-specific tweaks to the adtype - adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, - ) - end - return new{ - typeof(model), - typeof(adtype), - typeof(getlogdensity), - typeof(all_iden_ranges), - typeof(prep), - }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim - ) - end -end - -################################### -# LogDensityProblems.jl interface # -################################### -""" - fast_ldf_accs(getlogdensity::Function) - -Determine which accumulators are needed for fast evaluation with the given -`getlogdensity` function. -""" -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) - return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) -end -function fast_ldf_accs(::typeof(getlogprior_internal)) - return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) -end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) - -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} - model::M - getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} -end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ) - accs = fast_ldf_accs(f.getlogdensity) - _, vi = fast_evaluate!!(f.model, strategy, accs) - return f.getlogdensity(vi) -end - -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - )( - params - ) -end - -function LogDensityProblems.logdensity_and_gradient( - fldf::FastLDF, params::AbstractVector{<:Real} -) - return DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - ), - fldf._adprep, - fldf.adtype, - params, - ) -end - -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} -) where {M} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} -) where {M} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.dimension(fldf::FastLDF) - return fldf._dim -end - -###################################################### -# Helper functions to extract ranges and link status # -###################################################### - -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. 
-""" - get_ranges_and_linked(varinfo::VarInfo) - -Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter -representation, along with whether each variable is linked or unlinked. - -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. -""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) - end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7c7438c9f..65eab448e 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,312 +1,263 @@ -using AbstractMCMC: AbstractModel +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + AbstractInitStrategy, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI +using Random: Random """ - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. - -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. 
-""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) A struct which contains a model, along with all the information necessary to: - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) !!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. 
If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: + +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. 
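A short usage sketch of the LogDensityProblems.jl interface described here, mirroring the `jldoctest` from the previous version of this docstring (the reported numbers are the ones from that doctest and are illustrative):

```julia
using DynamicPPL, Distributions, LogDensityProblems
import ADTypes, ForwardDiff

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end
model = demo(1.0)

f = DynamicPPL.LogDensityFunction(model)
LogDensityProblems.logdensity(f, [0.0])  # -2.3378770664093453
LogDensityProblems.dimension(f)          # 1

# Restrict to the log prior only:
f_prior = DynamicPPL.LogDensityFunction(model, getlogprior)
LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)  # true

# With an AD backend, gradients become available:
f_ad = DynamicPPL.LogDensityFunction(model; adtype=ADTypes.AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(f_ad, [0.0])  # (-2.3378770664093453, [1.0])
```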
In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext`. + +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the +LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with +link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle +models which have variable numbers of parameters, or models which may visit random variables +in different orders depending on stochastic control flow. **Indeed, silent errors may occur +with such models.** This is a general limitation of vectorised parameters: the original +`unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + _dim::Int function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - if adtype === nothing - prep = nothing + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) + # Do AD prep if needed + prep = if adtype === nothing + nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + DI.prepare_gradient( + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end +################################### +# LogDensityProblems.jl interface # +################################### """ - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) + fast_ldf_accs(getlogdensity::Function) -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. """ -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. 
Please specify a VarInfo explicitly. - """ - return error(msg) +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, +function (f::LogDensityAt)(params::AbstractVector{<:Real}) + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) + accs = fast_ldf_accs(f.getlogdensity) + _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) + return f.getlogdensity(vi) +end -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} ) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + )( + params + ) end -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V +function LogDensityProblems.logdensity_and_gradient( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ), + ldf._adprep, + ldf.adtype, + params, ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. 
-""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) end -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} +) where {M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils """ tweak_adtype( @@ -325,53 +276,77 @@ By default, this just returns the input unchanged. """ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. +###################################################### +# Helper functions to extract ranges and link status # +###################################################### -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. 
The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. """ -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false + get_ranges_and_linked(varinfo::VarInfo) -""" - getmodel(f) +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model +This function should return a tuple containing: +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. """ - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. 
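The `get_ranges_and_linked` helpers above fix, at construction time, both the position of each variable in the flat vector and whether it is linked. A hedged sketch of the practical consequence (the model, the variable ordering in the vector, and the log-scale representation of `s` are illustrative assumptions):

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model function demo()
    s ~ Exponential()
    x ~ Normal(0, s)
end
model = demo()

# Link the VarInfo so `s` is stored on the unconstrained scale.
linked_vi = DynamicPPL.link!!(VarInfo(model), model)
f = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi)

# Ranges and link status were captured when `f` was built, so the vector passed
# to `logdensity` must supply `s` in its linked representation.
LogDensityProblems.logdensity(f, [log(0.5), 0.3])
```

Because the mapping is frozen, models whose set of variables or visit order changes between evaluations are not supported, as the extended help above notes.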
-""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/src/model.jl b/src/model.jl index 2bcfe8f98..9029318b1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -881,30 +881,56 @@ end [init_strategy::AbstractInitStrategy=InitFromPrior()] ) -Evaluate the `model` and replace the values of the model's random variables -in the given `varinfo` with new values, using a specified initialisation strategy. -If the values in `varinfo` are not set, they will be added -using a specified initialisation strategy. +Evaluate the `model` and replace the values of the model's random variables in the given +`varinfo` with new values, using a specified initialisation strategy. If the values in +`varinfo` are not set, they will be added using a specified initialisation strategy. If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function init!!( +@inline function init!!( + # Note that this `@inline` is mandatory for performance, especially for + # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for + # trivial models) and much slower runtime. rng::Random.AbstractRNG, model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), + vi::AbstractVarInfo, + strategy::AbstractInitStrategy=InitFromPrior(), ) - new_model = setleafcontext(model, InitContext(rng, init_strategy)) - return evaluate!!(new_model, varinfo) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + return if Threads.nthreads() > 1 + # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that + # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` + # won't have been filled with parameters prior to `init!!` being called. + # + # Note that this eltype promotion is only needed for threadsafe evaluation. In an + # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a + # similar method. In other words, it should not be here, and it should not be inside + # `unflatten` either. The problem is performance. Shifting this code around can have + # massive, inexplicable, impacts on performance. This should be investigated + # properly. 
+ param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(vi.accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + vi = DynamicPPL.setaccs!!(vi, accs) + tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + else + return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + end end -function init!!( - model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), +@inline function init!!( + model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - return init!!(Random.default_rng(), model, varinfo, init_strategy) + # This `@inline` is also mandatory for performance + return init!!(Random.default_rng(), model, vi, strategy) end """ diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index 43b877d62..12a9ece71 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -66,7 +66,7 @@ using Test end end -@testset "ParamsWithStats from FastLDF" begin +@testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) @@ -77,9 +77,9 @@ end end params = [x for x in vi[:]] - # Get 
the ParamsWithStats using FastLDF - fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) - ps = ParamsWithStats(params, fldf) + # Get the ParamsWithStats using LogDensityFunction + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) + ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected @test length(ps.params) == length(keys(vi)) diff --git a/test/fasteval.jl b/test/fasteval.jl deleted file mode 100644 index db2333711..000000000 --- a/test/fasteval.jl +++ /dev/null @@ -1,233 +0,0 @@ -module DynamicPPLFastLDFTests - -using AbstractPPL: AbstractPPL -using Chairmarks -using DynamicPPL -using Distributions -using DistributionsAD: filldist -using ADTypes -using DynamicPPL.Experimental: FastLDF -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest -using LinearAlgebra: I -using Test -using LogDensityProblems: LogDensityProblems - -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -# Need to include this block here in case we run this test file standalone -@static if VERSION < v"1.12" - using Pkg - Pkg.add("Mooncake") - using Mooncake: Mooncake -end - -@testset "FastLDF: Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$varinfo_func" for varinfo_func in [ - DynamicPPL.untyped_varinfo, - DynamicPPL.typed_varinfo, - DynamicPPL.untyped_vector_varinfo, - DynamicPPL.typed_vector_varinfo, - ] - unlinked_vi = varinfo_func(m) - @testset "$islinked" for islinked in (false, true) - vi = if islinked - DynamicPPL.link!!(unlinked_vi, m) - else - unlinked_vi - end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = [x for x in vi[:]] - # Iterate over all variables - for vn in keys(vi) - # Check that `getindex_internal` returns the same thing as using the ranges - # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity - nt_ranges[AbstractPPL.getsym(vn)] - else - dict_ranges[vn] - end - @test params[range_with_linked.range] == - DynamicPPL.getindex_internal(vi, vn) - # Check that the link status is correct - @test range_with_linked.is_linked == islinked - end - - # Compare results of FastLDF vs ordinary LogDensityFunction. These tests - # can eventually go once we replace LogDensityFunction with FastLDF, but - # for now it helps to have this check! (Eventually we should just check - # against manually computed log-densities). - # - # TODO(penelopeysm): I think we need to add tests for some really - # pathological models here. - @testset "$getlogdensity" for getlogdensity in ( - DynamicPPL.getlogjoint_internal, - DynamicPPL.getlogjoint, - DynamicPPL.getloglikelihood, - DynamicPPL.getlogprior_internal, - DynamicPPL.getlogprior, - ) - ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) - fldf = FastLDF(m, getlogdensity, vi) - @test LogDensityProblems.logdensity(ldf, params) ≈ - LogDensityProblems.logdensity(fldf, params) - end - end - end - end - - @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end - end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) - end - end -end - -@testset "FastLDF: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) - end - end -end - -@testset "AD with FastLDF" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = @static if VERSION < v"1.12" - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) - linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = [p for p in linked_varinfo[:]] - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $adtype" - - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. 
- @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end - -end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index fbd868f71..06492d6e1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,49 +1,240 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model +module DynamicPPLLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake + +@testset "LogDensityFunction: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if 
islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end end end -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) +@testset "LogDensityFunction: interface" begin + # miscellaneous parts of the LogDensityProblems interface + @testset "dimensions" begin + @model function m1() + x ~ Normal() + y ~ Normal() + return nothing + end + model = m1() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 2 + + @model function m2() + x ~ Dirichlet(ones(4)) + y ~ Categorical(x) + return nothing end + model = m2() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 5 + linked_vi = DynamicPPL.link!!(VarInfo(model), model) + ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi) + @test LogDensityProblems.dimension(ldf) == 4 end @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @model f() = x ~ Normal() + model = f() + # No adtype ldf = DynamicPPL.LogDensityFunction(model) @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + # With adtype + ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{1}() end end + +@testset "LogDensityFunction: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(ldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with LogDensityFunction" begin + # Used as the ground truth that others are compared against. 
+ ref_adtype = AutoForwardDiff() + + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 1474b426a..9649aebbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using ForwardDiff using LogDensityProblems using MacroTools using MCMCChains +using Mooncake using StableRNGs using ReverseDiff using Mooncake @@ -57,7 +58,6 @@ include("test_util.jl") include("simple_varinfo.jl") 
include("model.jl") include("distribution_wrappers.jl") - include("logdensityfunction.jl") include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -68,10 +68,11 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("chains.jl") - include("bijector.jl") end if GROUP == "All" || GROUP == "Group2" + include("bijector.jl") + include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") @@ -80,8 +81,6 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") - include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." From 60736ba6ce9a58222621387fd34ad1dbe55afe5d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 15 Nov 2025 17:02:12 +0000 Subject: [PATCH 12/15] Improve type stability when all parameters are linked or unlinked --- src/contexts/init.jl | 20 ++++++++++---- src/logdensityfunction.jl | 56 +++++++++++++++++++++++++++++--------- test/logdensityfunction.jl | 16 +++++++++++ 3 files changed, 74 insertions(+), 18 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..b70bf2bf1 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,11 +258,15 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, -) + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} vr = p.params range_and_linked = _get_range_and_linked(vr, vn) - transform = if range_and_linked.is_linked + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked from_linked_vec_transform(dist) else from_vec_transform(dist) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 65eab448e..ffffa9568 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. 
""" struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -163,6 +166,21 @@ struct LogDensityFunction{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Figure out if all variables are linked, unlinked, or mixed + link_statuses = Bool[] + for ral in all_iden_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in all_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -172,12 +190,13 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -209,15 +228,24 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = fast_ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) @@ -225,9 +253,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params @@ -235,10 +263,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - 
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 06492d6e1..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -108,6 +108,22 @@ end end end +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when From d953f7b46c06198247a1d388a8dc414eff703759 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 12:36:24 +0000 Subject: [PATCH 13/15] fix a merge conflict --- src/chains.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index f176b8e68..4660a1a31 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -156,13 +156,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.LogDensityFunction, + ldf::DynamicPPL.LogDensityFunction{Tlink}, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, -) +) where {Tlink} strategy = InitFromParams( - VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), nothing, ) accs = if include_log_probs From d7da26dc9de4ef27138b54670c5af1e1a4e0a5fb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 12:41:50 +0000 Subject: [PATCH 14/15] fix enzyme gc crash (locally at least) --- test/integration/enzyme/main.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..edfd67d18 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,11 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => +ADTYPES = ( + ( + "EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), - "EnzymeReverse" => + ), + ( + "EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), + ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES From dd9389fcd15687838caf7b7005a332189157d39b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 21:39:19 +0000 Subject: [PATCH 15/15] RangeLinkedValAcc --- ext/DynamicPPLForwardDiffExt.jl | 6 +- src/logdensityfunction.jl | 213 ++++++++++++++++----------- test/ext/DynamicPPLForwardDiffExt.jl | 5 +- 3 files changed, 128 insertions(+), 96 deletions(-) diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 7ea51918f..4b6d3fb41 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ 
b/ext/DynamicPPLForwardDiffExt.jl @@ -8,12 +8,8 @@ use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false function DynamicPPL.tweak_adtype( - ad::ADTypes.AutoForwardDiff{chunk_size}, - ::DynamicPPL.Model, - vi::DynamicPPL.AbstractVarInfo, + ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, params::AbstractVector ) where {chunk_size} - params = vi[:] - # Use DynamicPPL tag to improve stack traces # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ # NOTE: DifferentiationInterface disables tag checking if the diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index ffffa9568..abfb61c94 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -157,38 +157,44 @@ struct LogDensityFunction{ _adprep::ADP _dim::Int + """ + function LogDensityFunction( + model::Model, + getlogdensity::Function=getlogjoint_internal, + link::Union{Bool,Set{VarName}}=false; + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + + Generate a `LogDensityFunction` for the given model. + + The `link` argument specifies which VarNames in the model should be linked. This can + either be a Bool (if `link=true` all variables are linked; if `link=false` all variables + are unlinked); or a `Set{VarName}` specifying exactly which variables should be linked. + Any sub-variables of the set's elements will be linked. + """ function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); + link::Union{Bool,Set{VarName}}=false; adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - # Figure out which variable corresponds to which index, and - # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) - # Figure out if all variables are linked, unlinked, or mixed - link_statuses = Bool[] - for ral in all_iden_ranges - push!(link_statuses, ral.is_linked) - end - for (_, ral) in all_ranges - push!(link_statuses, ral.is_linked) - end - Tlink = if all(link_statuses) - true - elseif all(!s for s in link_statuses) - false - else - nothing - end - x = [val for val in varinfo[:]] + # Run the model once to determine variable ranges and linking. Because the + # parameters stored in the LogDensityFunction are never used, we can just use + # InitFromPrior to create new values. The actual values don't matter, only the + # length, since that's used for gradient prep. + vi = OnlyAccsVarInfo(AccumulatorTuple((RangeLinkedValueAcc(link),))) + _, vi = DynamicPPL.init!!(model, vi, InitFromPrior()) + rlvacc = first(vi.accs) + Tlink, all_iden_ranges, all_ranges, x = get_data(rlvacc) + @info Tlink, all_iden_ranges, all_ranges, x + # That gives us all the information we need to create the LogDensityFunction. dim = length(x) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype - adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + adtype = DynamicPPL.tweak_adtype(adtype, model, x) DI.prepare_gradient( LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, @@ -293,7 +299,7 @@ end tweak_adtype( adtype::ADTypes.AbstractADType, model::Model, - varinfo::AbstractVarInfo, + params::AbstractVector ) Return an 'optimised' form of the adtype. This is useful for doing @@ -304,79 +310,108 @@ model. By default, this just returns the input unchanged. 
""" -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype -###################################################### -# Helper functions to extract ranges and link status # -###################################################### +############################## +# RangeLinkedVal accumulator # +############################## -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. -""" - get_ranges_and_linked(varinfo::VarInfo) - -Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter -representation, along with whether each variable is linked or unlinked. - -This function should return a tuple containing: +struct RangeLinkedValueAcc{L<:Union{Bool,Set{VarName}},N<:NamedTuple} <: AbstractAccumulator + should_link::L + current_index::Int + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} + values::Vector{Any} +end +function RangeLinkedValueAcc(should_link::Union{Bool,Set{VarName}}) + return RangeLinkedValueAcc(should_link, 1, (;), Dict{VarName,RangeAndLinked}(), Any[]) +end -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. -""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) +function get_data(rlvacc::RangeLinkedValueAcc) + link_statuses = Bool[] + for ral in rlvacc.iden_varname_ranges + push!(link_statuses, ral.is_linked) end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others + for (_, ral) in rlvacc.varname_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end + return ( + Tlink, rlvacc.iden_varname_ranges, rlvacc.varname_ranges, [v for v in rlvacc.values] + ) end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) + +accumulator_name(::Type{<:RangeLinkedValueAcc}) = :RangeLinkedValueAcc +accumulate_observe!!(acc::RangeLinkedValueAcc, dist, val, vn) = acc +function accumulate_assume!!( + acc::RangeLinkedValueAcc, val, logjac, 
vn::VarName{sym}, dist::Distribution +) where {sym} + link_this_vn = if acc.should_link isa Bool + acc.should_link + else + # Set{VarName} + any(should_link_vn -> subsumes(should_link_vn, vn), acc.should_link) + end + val = if link_this_vn + to_linked_vec_transform(dist)(val) + else + to_vec_transform(dist)(val) + end + new_values = vcat(acc.values, val) + len = length(val) + range = (acc.current_index):(acc.current_index + len - 1) + ral = RangeAndLinked(range, link_this_vn) + iden_varnames, other_varnames = if getoptic(vn) === identity + merge(acc.iden_varname_ranges, (sym => ral,)), acc.varname_ranges + else + acc.varname_ranges[vn] = ral + acc.iden_varname_ranges, acc.varname_ranges end - return all_iden_ranges, all_ranges, offset + return RangeLinkedValueAcc( + acc.should_link, acc.current_index + len, iden_varnames, other_varnames, new_values + ) end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) +function Base.copy(acc::RangeLinkedValueAcc) + return RangeLinkedValueAcc( + acc.should_link, + acc.current_index, + acc.iden_varname_ranges, + copy(acc.varname_ranges), + copy(acc.values), + ) +end +_zero(acc::RangeLinkedValueAcc) = RangeLinkedValueAcc(acc.should_link) +reset(acc::RangeLinkedValueAcc) = _zero(acc) +split(acc::RangeLinkedValueAcc) = _zero(acc) +function combine(acc1::RangeLinkedValueAcc, acc2::RangeLinkedValueAcc) + new_values = vcat(acc1.values, acc2.values) + new_current_index = acc1.current_index + acc2.current_index - 1 + acc2_iden_varnames_shifted = NamedTuple( + k => RangeAndLinked((ral.range .+ (acc1.current_index - 1)), ral.is_linked) for + (k, ral) in pairs(acc2.iden_varname_ranges) + ) + new_iden_varname_ranges = merge(acc1.iden_varname_ranges, acc2_iden_varnames_shifted) + acc2_varname_ranges_shifted = Dict{VarName,RangeAndLinked}() + for (k, ral) in acc2.varname_ranges + acc2_varname_ranges_shifted[k] = RangeAndLinked( + (ral.range .+ (acc1.current_index - 1)), ral.is_linked + ) end - return all_iden_ranges, all_ranges, offset + new_varname_ranges = merge(acc1.varname_ranges, acc2_varname_ranges_shifted) + return RangeLinkedValueAcc( + # TODO: using acc1.should_link is not really 'correct', but `should_link` only + # affects model evaluation and `combine` only runs at the end of model evaluation, + # so it shouldn't matter + acc1.should_link, + new_current_index, + new_iden_varname_ranges, + new_varname_ranges, + new_values, + ) end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 44db66296..b58c3e7bc 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,16 +14,17 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) + x = varinfo[:] @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype isa 
AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end
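
The snippet below is a minimal usage sketch of the `link`-based `LogDensityFunction` constructor introduced in the last patch; it is not part of the patches themselves. The model `demo`, its data, and the choice of `AutoForwardDiff` are illustrative assumptions, while the constructor arguments and the `LogDensityProblems` calls follow the definitions shown in the diffs above.

    using DynamicPPL, Distributions, ADTypes
    using ForwardDiff: ForwardDiff   # loads the AD backend behind AutoForwardDiff
    using LogDensityProblems

    @model function demo(x)
        mu ~ Normal()
        sigma ~ Exponential(1.0)   # constrained; with link=true it is mapped to the real line
        x ~ Normal(mu, sigma)
    end

    model = demo(1.5)

    # link=true marks every variable as linked, so the Tlink type parameter is `true`
    # and the link status is known statically (per the type-stability patch above).
    # A Set{VarName} could be passed instead to link only a subset of variables.
    ldf = DynamicPPL.LogDensityFunction(
        model, DynamicPPL.getlogjoint_internal, true; adtype=AutoForwardDiff()
    )

    # With all variables linked, any real-valued vector of the right length is valid.
    params = randn(LogDensityProblems.dimension(ldf))
    lp = LogDensityProblems.logdensity(ldf, params)
    lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)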