From 8c3d30f78b375e00c1ed4fba75b0d8d0414e43a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 21 Oct 2025 18:08:25 +0100 Subject: [PATCH 01/45] v0.39 --- HISTORY.md | 2 ++ Project.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 57ccaecd1..8165317f6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,7 @@ # DynamicPPL Changelog +## 0.39.0 + ## 0.38.0 ### Breaking changes diff --git a/Project.toml b/Project.toml index 2fe65fd7b..7f58083bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.0" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7300c224dc0fbac2febea2b6185c74f48181dcc9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:48:58 +0000 Subject: [PATCH 02/45] Update DPPL compats for benchmarks and docs --- benchmarks/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 0d4e9a654..55ca81da0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -23,7 +23,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" diff --git a/docs/Project.toml b/docs/Project.toml index fed06ebde..69e0a4c5a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" From 79150baf0a70380f46d9573746a900f7c38bb370 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Nov 2025 17:31:18 +0000 Subject: [PATCH 03/45] remove merge conflict markers --- HISTORY.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cd4cbc767..45be1772d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -19,8 +19,6 @@ The generic method `returned(::Model, values, keys)` is deprecated and will be r Added a compatibility entry for JET@0.11. -> > > > > > > main - ## 0.38.1 Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`. From 4ca95281ebbe451f99813f009bfc9c7140330a45 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 16:24:14 +0000 Subject: [PATCH 04/45] Remove `NodeTrait` (#1133) * Remove NodeTrait * Changelog * Fix exports * docs * fix a bug * Fix doctests * Fix test * tweak changelog --- HISTORY.md | 17 +++++++ docs/src/api.md | 46 +++++++++++++---- src/DynamicPPL.jl | 13 +++-- src/contexts.jl | 82 ++++++++++++------------------ src/contexts/conditionfix.jl | 92 +++++++++++----------------------- src/contexts/default.jl | 1 - src/contexts/init.jl | 1 - src/contexts/prefix.jl | 17 ++----- src/contexts/transformation.jl | 1 - src/model.jl | 4 +- src/test_utils/contexts.jl | 19 ++----- test/contexts.jl | 36 +++++++------ 12 files changed, 154 insertions(+), 175 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 613957c33..f181897f7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,23 @@ ## 0.39.0 +### Breaking changes + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. 
+This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`. + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users. + +Leaf contexts require no changes, apart from a removal of the `NodeTrait` function. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### Miscellaneous + Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. ## 0.38.9 diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..63dafdfca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -352,13 +352,6 @@ Base.empty! SimpleVarInfo ``` -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -463,15 +456,48 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context. + +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks. +They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. 
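As a quick illustration of the parent-context interface documented in the `docs/src/api.md` changes above, here is a minimal sketch of a custom parent context. The type name `MyWrapperContext` is purely illustrative and not part of DynamicPPL; it only implements the two required methods and otherwise behaves exactly like its child context.

```julia
using DynamicPPL

# A parent context that does nothing except wrap its child.
# `MyWrapperContext` is an illustrative name, not part of DynamicPPL.
struct MyWrapperContext{C<:AbstractContext} <: AbstractParentContext
    context::C
end
DynamicPPL.childcontext(ctx::MyWrapperContext) = ctx.context
DynamicPPL.setchildcontext(::MyWrapperContext, child) = MyWrapperContext(child)

# `tilde_assume!!` and `tilde_observe!!` need not be implemented: the fallback
# methods simply delegate to the child context.
ctx = MyWrapperContext(DefaultContext())
leafcontext(ctx)  # returns DefaultContext()
```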
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..c43bd89d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -94,16 +94,21 @@ export AbstractVarInfo, values_as_in_model, # LogDensityFunction LogDensityFunction, - # Contexts + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. """ -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. 
```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! not implemented for context of type $(typeof(context))") +end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. 
""" struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ function getconditioned_nested(::IsParent, context, vn) end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. 
""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants. 
""" function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. 
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..83507353f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -150,7 +150,6 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext vn_prefix::Tvn context::C end @@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} end PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) -NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context function setchildcontext(ctx::PrefixContext, child::AbstractContext) return PrefixContext(ctx.vn_prefix, child) @@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`. 
function prefix(ctx::PrefixContext, vn::VarName) return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) +prefix(::AbstractContext, vn::VarName) = vn +function prefix(ctx::AbstractParentContext, vn::VarName) return prefix(childcontext(ctx), vn) end @@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) ) return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) +prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName) vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) return vn, setchildcontext(ctx, new_ctx) end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c2eee2863 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -10,7 +10,6 @@ Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the how to do the transformation, used by e.g. `SimpleVarInfo`. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/model.jl b/src/model.jl index ec98b90cd..94fcd9fd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. 
""" function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..ae7332a43 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. 
- function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) From 535ce4f68e8f162fb382fb5d55eae0238d332e7a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 13 Nov 2025 13:30:43 +0000 Subject: [PATCH 05/45] FastLDF / InitContext unified (#1132) * Fast Log Density Function * Make it work with AD * Optimise performance for identity VarNames * Mark `get_range_and_linked` as having zero derivative * Update comment * make AD testing / benchmarking use FastLDF * Fix tests * Optimise away `make_evaluate_args_and_kwargs` * const func annotation * Disable benchmarks on non-typed-Metadata-VarInfo * Fix `_evaluate!!` correctly to handle submodels * Actually fix submodel evaluate * Document thoroughly and organise code * Support more VarInfos, make it thread-safe (?) * fix bug in parsing ranges from metadata/VNV * Fix get_param_eltype for TSVI * Disable Enzyme benchmark * Don't override _evaluate!!, that breaks ForwardDiff (sometimes) * Move FastLDF to experimental for now * Fix imports, add tests, etc * More test fixes * Fix imports / tests * Remove AbstractFastEvalContext * Changelog and patch bump * Add correctness tests, fix imports * Concretise parameter vector in tests * Add zero-allocation tests * Add Chairmarks as test dep * Disable allocations tests on multi-threaded * Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation * Refactor FastLDF to use InitContext * note init breaking change * fix logjac sign * workaround Mooncake segfault * fix changelog too * Fix get_param_eltype for context stacks * Add a test for threaded observe * Export init * Remove dead code * fix transforms for pathological distributions * Tidy up loads of things * fix typed_identity spelling * fix definition order * Improve docstrings * Remove stray comment * export get_param_eltype (unfortunatley) * Add more comment * Update comment * Remove inlines, fix OAVI docstring * Improve docstrings * Simplify InitFromParams constructor * Replace map(identity, x[:]) with [i for i in x[:]] * Simplify implementation for InitContext/OAVI * Add another model to allocation tests Co-authored-by: Markus Hauru * Revert removal of dist argument (oops) * Format * Update some outdated bits of FastLDF docstring * remove underscores --------- Co-authored-by: Markus Hauru --- HISTORY.md | 15 ++ docs/src/api.md | 12 +- ext/DynamicPPLEnzymeCoreExt.jl | 13 +- ext/DynamicPPLMooncakeExt.jl | 3 + src/DynamicPPL.jl | 5 +- src/compiler.jl | 38 ++-- src/contexts/init.jl | 255 ++++++++++++++++++++---- src/experimental.jl | 2 + src/fasteval.jl | 336 ++++++++++++++++++++++++++++++++ src/model.jl | 32 ++- src/onlyaccs.jl | 42 ++++ src/utils.jl | 35 ++++ test/Project.toml | 1 + test/fasteval.jl | 233 ++++++++++++++++++++++ test/integration/enzyme/main.jl | 6 +- test/runtests.jl | 1 + 16 files changed, 955 insertions(+), 74 deletions(-) create mode 100644 src/fasteval.jl create mode 100644 src/onlyaccs.jl create mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 
f181897f7..0f0102ce4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,21 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). + +### Other changes + +#### FastLDF + +Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +Please note that `FastLDF` is currently considered internal and its API may change without warning. +We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. + +For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index 63dafdfca..e81f18dc7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -170,6 +170,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -517,10 +523,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. +In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL._get_range_and_linked), args... 
+) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -5,5 +5,8 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL._get_range_and_linked),Vararg +} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c43bd89d5..e9b902363 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -84,8 +84,8 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, @@ -113,6 +113,8 @@ export AbstractVarInfo, InitFromPrior, InitFromUniform, InitFromParams, + init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, @@ -193,6 +195,7 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 83507353f..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -14,14 +14,60 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. 
The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion +""" +get_param_eltype(::AbstractInitStrategy) = Any + """ InitFromPrior() @@ -29,7 +75,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return rand(rng, dist), typed_identity end """ @@ -69,43 +115,61 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return x, typed_identity end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple. +Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +## Custom parameter storage types -!!! 
note - The values in `params` must be provided in the space of the untransformed - distribution. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement + +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return InitFromParams(to_varname_dict(params), fallback) - end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) +InitFromParams(params) = InitFromParams(params, InitFromPrior()) + +function init( + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -119,13 +183,89 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. 
+ +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = if range_and_linked.is_linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return (@view vr.vect[range_and_linked.range]), transform +end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( @@ -155,9 +295,8 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we @@ -165,17 +304,49 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). 
It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. + y, -inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/experimental.jl b/src/experimental.jl index 8c82dca68..c644c09b2 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,6 +2,8 @@ module Experimental using DynamicPPL: DynamicPPL +include("fasteval.jl") + # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) 
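As a hedged illustration of the new `init` return convention introduced in this patch (a value plus a transform back to unlinked space), here is a sketch of a hypothetical initialisation strategy. `InitFromMean` is an invented name and not part of DynamicPPL; it initialises every variable at the mean of its prior.

```julia
using DynamicPPL, Distributions, Random

# Hypothetical strategy: initialise every variable at the mean of its prior.
struct InitFromMean <: AbstractInitStrategy end

function DynamicPPL.init(
    ::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromMean
)
    # The value is already in unlinked (model) space, so the accompanying
    # transform is just `typed_identity`.
    return mean(dist), typed_identity
end

# `get_param_eltype` falls back to `Any`, which is fine here because this
# strategy never carries tracer types such as ForwardDiff duals.
```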
diff --git a/src/fasteval.jl b/src/fasteval.jl new file mode 100644 index 000000000..c668b1413 --- /dev/null +++ b/src/fasteval.jl @@ -0,0 +1,336 @@ +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI +using Random: Random + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. 
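+
+## Example
+
+A minimal usage sketch (illustrative only, not a doctest; the model and parameter
+values below are made up for demonstration):
+
+```julia
+using DynamicPPL, Distributions, LogDensityProblems
+using DynamicPPL.Experimental: FastLDF
+using ADTypes: AutoForwardDiff
+import ForwardDiff
+
+@model function demo()
+    x ~ Normal()
+    return 1.0 ~ Normal(x)
+end
+model = demo()
+
+# Log joint density only (defaults to `getlogjoint_internal` and a fresh VarInfo):
+fldf = FastLDF(model)
+LogDensityProblems.logdensity(fldf, [0.5])
+
+# With an AD backend loaded, gradients are available too:
+fldf_ad = FastLDF(model; adtype=AutoForwardDiff())
+LogDensityProblems.logdensity_and_gradient(fldf_ad, [0.5])
+```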
+ +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext`. + +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we +store a mapping from VarNames to ranges in that vector, along with link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. 
+""" +struct FastLDF{ + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} + model::M + adtype::AD + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + + function FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) + end + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + ) + end +end + +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) +end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) +end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) + +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = InitContext( + Random.default_rng(), + InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + ), + ) + model = DynamicPPL.setleafcontext(f.model, ctx) + accs = fast_ldf_accs(f.getlogdensity) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + accs = map( + acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), + accs, + ) + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + _, vi = DynamicPPL._evaluate!!(model, vi) + return f.getlogdensity(vi) +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + return FastLogDensityAt( + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + )( + params + ) +end + +function LogDensityProblems.logdensity_and_gradient( + fldf::FastLDF, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + FastLogDensityAt( + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + ), + fldf._adprep, + fldf.adtype, + params, + ) +end + +###################################################### +# Helper functions to extract ranges and link status # +###################################################### + +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
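+
+For instance (purely illustrative; the actual ranges depend on the model and on the
+VarInfo's internal ordering), a model with a 2-dimensional linked variable `x` and a
+scalar unlinked variable `y[1]` might yield:
+
+```julia
+# NamedTuple for identity-optic VarNames:
+(x = RangeAndLinked(1:2, true),)
+# Dict for all other VarNames:
+Dict(@varname(y[1]) => RangeAndLinked(3:3, false))
+```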
+""" +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end diff --git a/src/model.jl b/src/model.jl index 94fcd9fd4..2bcfe8f98 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... + ) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1010,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) + +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. 
+""" +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) + return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) +end +get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) +function get_param_eltype(::AbstractVarInfo, ctx::InitContext) + return get_param_eltype(ctx.strategy) +end + """ getargnames(model::Model) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl new file mode 100644 index 000000000..940f23124 --- /dev/null +++ b/src/onlyaccs.jl @@ -0,0 +1,42 @@ +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for +`InitContext`. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this with a different leaf context such as `DefaultContext` will result in errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +This is also why it only works with `InitContext`: in this case, the parameters used for +evaluation are supplied by the context instead of the metadata. +""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# Minimal AbstractVarInfo interface +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +# Ideally, we'd define this together with InitContext, but alas that file comes way before +# this one, and sorting out the include order is a pain. +function tilde_assume!!( + ctx::InitContext, + dist::Distribution, + vn::VarName, + vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, +) + # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can + # cut out a lot of the code above. + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi +end diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..75fb805dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. 
+ +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. +""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..db2333711 --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,233 @@ +module DynamicPPLFastLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end + +@testset "FastLDF: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. 
+ @testset "$getlogdensity" for getlogdensity in ( + DynamicPPL.getlogjoint_internal, + DynamicPPL.getlogjoint, + DynamicPPL.getloglikelihood, + DynamicPPL.getlogprior_internal, + DynamicPPL.getlogprior, + ) + ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) + fldf = FastLDF(m, getlogdensity, vi) + @test LogDensityProblems.logdensity(ldf, params) ≈ + LogDensityProblems.logdensity(fldf, params) + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.Experimental.FastLDF(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end + end +end + +@testset "FastLDF: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with FastLDF" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = @static if VERSION < v"1.12" + [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + else + [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] + end + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. 
+ @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = FastLDF(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..10fac8b0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,7 @@ include("test_util.jl") include("ext/DynamicPPLMooncakeExt.jl") end include("ad.jl") + include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." 
From 9624103885a6f8b6cac70b1d0796da5a6227b65a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:34:59 +0000 Subject: [PATCH 06/45] implement `LogDensityProblems.dimension` --- src/fasteval.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index c668b1413..aa2fdd933 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -149,6 +149,7 @@ struct FastLDF{ _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adprep::ADP + _dim::Int function FastLDF( model::Model, @@ -159,13 +160,14 @@ struct FastLDF{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - x = [val for val in varinfo[:]] DI.prepare_gradient( FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), adtype, @@ -179,7 +181,7 @@ struct FastLDF{ typeof(all_iden_ranges), typeof(prep), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end @@ -260,6 +262,10 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.dimension(fldf::FastLDF) + return fldf._dim +end + ###################################################### # Helper functions to extract ranges and link status # ###################################################### From ce807139b31919b40afc98bcc522e2f41e14dc30 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Nov 2025 00:40:37 +0000 Subject: [PATCH 07/45] forgot about capabilities... 
--- src/fasteval.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fasteval.jl b/src/fasteval.jl index aa2fdd933..4f402f4a8 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -262,6 +262,16 @@ function LogDensityProblems.logdensity_and_gradient( ) end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} +) where {M} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} +) where {M} + return LogDensityProblems.LogDensityOrder{1}() +end function LogDensityProblems.dimension(fldf::FastLDF) return fldf._dim end From 8553e401182b894a653b638e5978f00ae42ae031 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 12:03:14 +0000 Subject: [PATCH 08/45] use interpolation in run_ad --- src/test_utils/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..d7a34e6e0 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -298,8 +298,8 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark - primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) - grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) + primal_benchmark = @be logdensity($ldf, $params) + grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) From 3cd8d3431e14ebc581266c1323d1db8a5bd4c0eb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Nov 2025 17:48:32 +0000 Subject: [PATCH 09/45] Improvements to benchmark outputs (#1146) * print output * fix * reenable * add more lines to guide the eye * reorder table * print tgrad / trel as well * forgot this type --- benchmarks/benchmarks.jl | 38 +++++++++++++++++++++++++++++++++----- src/test_utils/ad.jl | 10 +++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 3af6573cf..e8ffa7e0b 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -98,12 +98,15 @@ function run(; to_json=false) }[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name" + @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try results = benchmark(model, varinfo_choice, adbackend, islinked) + @info " t(eval) = $(results.primal_time)" + @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), (results.grad_time / results.primal_time) catch e + @info "benchmark errored: $e" missing, missing end push!( @@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) ) results_table = Tuple{ - String,Int,String,String,Bool,String,String,String,String,String,String + String, + Int, + String, + String, + Bool, + String, + String, + String, + String, + String, + String, + String, + String, + String, }[] + sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ EmptyCells(5), 
MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), + MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"], + [colnames[1:5]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String) # Finally that lets us do this division safely speedup_eval = base_eval / head_eval speedup_grad = base_grad / head_grad + # As well as this multiplication, which is t(grad) / t(ref) + head_grad_vs_ref = head_grad * head_eval + base_grad_vs_ref = base_grad * base_eval + speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref push!( results_table, ( @@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String) sprint_float(base_grad), sprint_float(head_grad), sprint_float(speedup_grad), + sprint_float(base_grad_vs_ref), + sprint_float(head_grad_vs_ref), + sprint_float(speedup_grad_vs_ref), ), ) end @@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String) backend=:text, fit_table_in_display_horizontally=false, fit_table_in_display_vertically=false, - table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true), + table_format=TextTableFormat(; + horizontal_line_at_merged_column_labels=true, + horizontal_lines_at_data_rows=collect(3:3:length(results_table)), + ), ) println("```") end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d7a34e6e0..8ee850877 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,7 +5,13 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link + DynamicPPL, + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + getlogjoint_internal, + link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -298,7 +304,9 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark + logdensity(ldf, params) # Warm-up primal_benchmark = @be logdensity($ldf, $params) + logdensity_and_gradient(ldf, params) # Warm-up grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time From eab71317d406a56fe06df7c8f944a4063e564112 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 17:08:31 +0000 Subject: [PATCH 10/45] Add VarNamedTuple, tests, and WIP docs --- docs/src/internals/varnamedtuple.md | 112 ++++++++++ src/varnamedtuple.jl | 310 ++++++++++++++++++++++++++++ test/varnamedtuple.jl | 89 ++++++++ 3 files changed, 511 insertions(+) create mode 100644 docs/src/internals/varnamedtuple.md create mode 100644 src/varnamedtuple.jl create mode 100644 test/varnamedtuple.jl diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md new file mode 100644 index 000000000..9f7a84cdb --- /dev/null +++ b/docs/src/internals/varnamedtuple.md @@ -0,0 +1,112 @@ +# VarNamedTuple as the basis of VarInfo + +This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that. 
+ +## The current situation + +We currently have the following AbstractVarInfo types: + + - A: VarInfo with Metadata + - B: VarInfo with VarNamedVector + - C: VarInfo with NamedTuple, with values being Metadata + - D: VarInfo with NamedTuple, with values being VarNamedVector + - E: SimpleVarInfo with NamedTuples + - F: SimpleVarInfo with OrderedDict + +A and C are the classic ones, and the defaults. C wraps groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types and for the VarInfo to still be type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks are quite a lot slower, which needs work. + +E and F are entirely distinct in implementation from the others. E is simply a mapping from Symbols to values, with each VarName being converted to a single symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type. + +A-D carry within them values for variables, but also their bijectors/distributions, and store all values vectorised, using the bijectors to map to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit. + +TODO: Write a better summary of pros and cons of each approach. + +## VarNamedTuple + +VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion. + +An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well. + +Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code. + +A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could hold with them extra information. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked. + +I sometimes shorten VarNamedTuple to VNT. + +Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as + +``` +(; x=1, y=(; z=2)) +``` + +(This is a slight simplification, really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.) +This forms a tree, with each node being a NamedTuple, like so: + +``` + NT +x / \ y + 1 NT + \ z + 2 +``` + +Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has the keys `x` and `y`. 
This is like with the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree. + +For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`. + +This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: They are either the raw value itself if the last lens of the VarName is an `identity`, or otherwise they are something that can be indexed with an `IndexLens`, such as an `Array`. + +To get a value from a VarNamedTuple is very simple: For `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it. + +Setting values in a VNT is equally simple if there are no `IndexLenses`: For `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`. + +The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`. + + 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: Just set the third element. + 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector? + 3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error? + +A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: Any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary. + +To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this: + +Unlike I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in the leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`. 
+
+Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like
+
+```
+       --NT--
+    x /      \ y
+f(1, identity) NT
+                \ z
+               f(2, identity)
+```
+
+The above first draft of VNT, which did not include `make_leaf`, is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal with IndexLenses")`. The problems 2. and 3. above are "solved" by making it `make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type that can grow as needed when indexed into.
+
+The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and, if necessary, leave the option for power users to customise behaviour. The first ones to implement would be:
+
+ - `make_leaf` that returns a Metadata object. This would be a direct replacement for type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple.
+ - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict.
+
+You may ask: have we simply gone from too many VarInfo types to too many `make_leaf` functions? Yes we have. But hopefully we have gained something in the process:
+
+ - The leaf types can be simpler. They do not need to deal with VarNames any more, they only need to deal with `identity` lenses and `IndexLenses`.
+ - All AbstractVarInfos are as type stable as their leaf types allow. There is no more notion of an untyped VarInfo being converted to a typed one.
+ - Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels.
+ - Many functions that are currently implemented individually for each AbstractVarInfo type would now have a single implementation for the VarNamedTuple-based AbstractVarInfo type, reducing code duplication. I would also hope to get rid of most of the generated functions in `varinfo.jl`.
+
+My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know.
+
+I think the two big questions are:
+
+ - Will we run into some big, unanticipated blockers when we start to implement this?
+ - Will the nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up?
+
+I'll try to derisk these early on in this PR.
+
+## Questions / issues
+
+ - People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that it needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
+ - When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
+ - Do `Colon` indices cause any extra trouble for the leaf nodes?
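+
+## Usage sketch
+
+For concreteness, here is roughly what the current prototype's interface looks like
+(mirroring the tests in `test/varnamedtuple.jl`; the API is still in flux):
+
+```julia
+using DynamicPPL: @varname, VarNamedTuple
+using BangBang: setindex!!
+
+vnt = VarNamedTuple()
+vnt = setindex!!(vnt, 32.0, @varname(a))       # identity lens: store the raw value
+vnt = setindex!!(vnt, [1, 2, 3], @varname(b))  # whole array stored at the leaf for `b`
+vnt = setindex!!(vnt, 15, @varname(b[2]))      # IndexLens indexes into that leaf
+vnt = setindex!!(vnt, [10], @varname(c.x.y))   # nested PropertyLenses add tree levels
+
+vnt[@varname(a)]          # 32.0
+vnt[@varname(b)]          # [1, 15, 3]
+vnt[@varname(c.x.y[1])]   # 10
+haskey(vnt, @varname(d))  # false
+```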
diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..448ae4636 --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,310 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL. +module VarNamedTuples + +using AbstractPPL +using BangBang +using Accessors +using DynamicPPL: _compose_no_identity + +export VarNamedTuple + +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + +_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::IndexLens{T}) where {T} + return any(x <: UnitRange || x <: Colon for x in T.parameters) +end + +struct VarNamedTuple{T<:Function,Names,Values} + data::NamedTuple{Names,Values} + make_leaf::T +end + +struct IndexDict{T<:Function,Keys,Values} + data::Dict{Keys,Values} + make_leaf::T +end + +struct PartialArray{T<:Function,ElType,numdims} + data::Array{ElType,numdims} + mask::Array{Bool,numdims} + make_leaf::T +end + +function PartialArray(eltype, num_dims, make_leaf) + dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + data = Array{eltype,num_dims}(undef, dims) + mask = fill(false, dims) + return PartialArray(data, mask, make_leaf) +end + +_length_needed(i::Integer) = i +_length_needed(r::UnitRange) = last(r) +_length_needed(::Colon) = 0 + +"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size +we choose it to be. This size will be the smallest possible power of +PARTIAL_ARRAY_DIM_GROWTH_FACTOR. Growing PartialArrays in big jumps like this helps reduce +data copying, as resizes aren't needed as often. +""" +function _partial_array_dim_size(min_dim) + factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR + return factor^(Int(ceil(log(factor, min_dim)))) +end + +function _resize_partialarray(iarr::PartialArray, inds) + min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + new_sizes = map(_partial_array_dim_size, min_sizes) + # Generic multidimensional Arrays can not be resized, so we need to make a new one. + # See https://github.com/JuliaLang/julia/issues/37900 + new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_mask = fill(false, new_sizes) + # Note that we have to use CartesianIndices instead of eachindex, because the latter + # may use a linear index that does not match between the old and the new arrays. + for i in CartesianIndices(iarr.data) + mask_val = iarr.mask[i] + @inbounds new_mask[i] = mask_val + if mask_val + @inbounds new_data[i] = iarr.data[i] + end + end + return PartialArray(new_data, new_mask, iarr.make_leaf) +end + +# The below implements the same functionality as above, but more performantly for 1D arrays. +function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} + # Resize arrays to accommodate new indices. + old_size = size(iarr.data, 1) + min_size = max(old_size, _length_needed(ind)) + new_size = _partial_array_dim_size(min_size) + resize!(iarr.data, new_size) + resize!(iarr.mask, new_size) + @inbounds iarr.mask[(old_size + 1):new_size] .= false + return iarr +end + +function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) + if _has_colon(optic) + # TODO(mhauru) This could be implemented by getting size information from `value`. + # However, the corresponding getindex is more fundamentally ill-defined. 
+ throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + iarr = if checkbounds(Bool, iarr.mask, inds...) + iarr + else + _resize_partialarray(iarr, inds) + end + new_data = setindex!!(iarr.data, value, inds...) + if _is_multiindex(optic) + iarr.mask[inds...] .= true + else + iarr.mask[inds...] = true + end + return PartialArray(new_data, iarr.mask, iarr.make_leaf) +end + +function Base.getindex(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + if !haskey(iarr, optic) + throw(BoundsError(iarr, inds)) + end + return getindex(iarr.data, inds...) +end + +function Base.haskey(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + return checkbounds(Bool, iarr.mask, inds...) && + all(@inbounds(getindex(iarr.mask, inds...))) +end + +function make_leaf_array(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +end +make_leaf_array(value, ::typeof(identity)) = value +function make_leaf_array(value, optic::ComposedFunction) + sub = make_leaf_array(value, optic.outer) + return make_leaf_array(sub, optic.inner) +end + +function make_leaf_array(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic) ? 
eltype(value) : typeof(value) + iarr = PartialArray(et, num_inds, make_leaf_array) + return setindex!!(iarr, value, optic) +end + +function make_leaf_dict(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) +end +make_leaf_dict(value, ::typeof(identity)) = value +function make_leaf_dict(value, optic::ComposedFunction) + sub = make_leaf_dict(value, optic.outer) + return make_leaf_dict(sub, optic.inner) +end +function make_leaf_dict(value, optic::IndexLens) + return IndexDict(Dict(optic.indices => value), make_leaf_dict) +end + +VarNamedTuple() = VarNamedTuple((;), make_leaf_array) + +function Base.show(io::IO, vnt::VarNamedTuple) + print(io, "(") + for (i, (name, value)) in enumerate(pairs(vnt.data)) + if i > 1 + print(io, ", ") + end + print(io, name, " -> ") + print(io, value) + end + return print(io, ")") +end + +function Base.show(io::IO, id::IndexDict) + return print(io, id.data) +end + +Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] + +function varname_to_lens(name::VarName{S}) where {S} + return _compose_no_identity(getoptic(name), PropertyLens{S}()) +end + +function Base.getindex(vnt::VarNamedTuple, name::VarName) + return getindex(vnt, varname_to_lens(name)) +end +function Base.getindex( + x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction +) + subdata = getindex(x, optic.inner) + return getindex(subdata, optic.outer) +end +function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return getindex(vnt.data, S) +end +function Base.getindex(id::IndexDict, optic::IndexLens) + return getindex(id.data, optic.indices) +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName) + return haskey(vnt, varname_to_lens(name)) +end + +Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true + +function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) + return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) +end + +Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) +Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) +Base.haskey(::VarNamedTuple, ::IndexLens) = false +Base.haskey(::IndexDict, ::PropertyLens) = false + +# TODO(mhauru) This is type piracy. +Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) + +# TODO(mhauru) This is type piracy. +function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) + return BangBang.setindex!!(arr, value, optic.indices...) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) + return BangBang.setindex!!(vnt, value, varname_to_lens(name)) +end + +function BangBang.setindex!!( + vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction +) + sub = if haskey(vnt, optic.inner) + BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) + else + vnt.make_leaf(value, optic.outer) + end + return BangBang.setindex!!(vnt, sub, optic.inner) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} + # I would like this to just read + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the + # below? 
+ return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) +end + +function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens) + return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf) +end + +function apply(func, vnt::VarNamedTuple, name::VarName) + if !haskey(vnt.data, name.name) + throw(KeyError(repr(name))) + end + subdata = getindex(vnt, name) + new_subdata = func(subdata) + return BangBang.setindex!!(vnt, new_subdata, name) +end + +function Base.map(func, vnt::VarNamedTuple) + new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) + return VarNamedTuple(new_data, vnt.make_leaf) +end + +function Base.keys(vnt::VarNamedTuple) + result = () + for sym in keys(vnt.data) + subdata = vnt.data[sym] + if subdata isa VarNamedTuple + subkeys = keys(subdata) + result = ( + (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result... + ) + else + result = (VarName{sym}(), result...) + end + subkeys = keys(vnt.data[sym]) + end + return result +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} + if !haskey(vnt.data, S) + return false + end + subdata = vnt.data[S] + return if Optic === typeof(identity) + true + elseif Optic <: IndexLens + try + AbstractPPL.getoptic(name)(subdata) + true + catch e + if e isa BoundsError || e isa KeyError + false + else + rethrow(e) + end + end + else + haskey(subdata, AbstractPPL.unprefix(name, VarName{S}())) + end +end + +end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl new file mode 100644 index 000000000..85b824ffc --- /dev/null +++ b/test/varnamedtuple.jl @@ -0,0 +1,89 @@ +module VarNamedTupleTests + +using Test: @inferred, @test, @test_throws, @testset +using DynamicPPL: @varname, VarNamedTuple +using BangBang: setindex!! + +@testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) +end + +end From 0c7825bd9b80494459393ea6b7349885d8c2e29c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 11:58:22 +0000 Subject: [PATCH 11/45] Add comparisons and merge --- src/varnamedtuple.jl | 179 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 120 ++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 21 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 448ae4636..006e8f0d5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -4,16 +4,18 @@ module VarNamedTuples using AbstractPPL using BangBang using Accessors -using DynamicPPL: _compose_no_identity +using ..DynamicPPL: _compose_no_identity export VarNamedTuple """The factor by which we increase the dimensions of PartialArrays when resizing them.""" const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 -_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) +const INDEX_TYPES = Union{Integer,UnitRange,Colon} -function _is_multiindex(::IndexLens{T}) where {T} +_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::T) where {T<:Tuple} return any(x <: UnitRange || x <: Colon for x in T.parameters) end @@ -22,6 +24,12 @@ struct VarNamedTuple{T<:Function,Names,Values} make_leaf::T end +# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for +# PartialArrays. 
+function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data +end + struct IndexDict{T<:Function,Keys,Values} data::Dict{Keys,Values} make_leaf::T @@ -33,13 +41,44 @@ struct PartialArray{T<:Function,ElType,numdims} make_leaf::T end -function PartialArray(eltype, num_dims, make_leaf) +function PartialArray(eltype, num_dims, make_leaf=make_leaf_array) dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) return PartialArray(data, mask, make_leaf) end +Base.ndims(iarr::PartialArray) = ndims(iarr.data) + +# We deliberately don't define Base.size for PartialArray, because it is ill-defined. +# The size of the .data field is an implementation detail. +_internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) + +function Base.copy(pa::PartialArray) + return PartialArray(copy(pa.data), copy(pa.mask), pa.make_leaf) +end + +function Base.:(==)(pa1::PartialArray, pa2::PartialArray) + if (pa1.make_leaf !== pa2.make_leaf) || (ndims(pa1) != ndims(pa2)) + return false + end + size1 = _internal_size(pa1) + size2 = _internal_size(pa2) + # TODO(mhauru) This could be optimised, but not sure it's worth it. + merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + for i in CartesianIndices(merge_size) + m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false + m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false + if m1 != m2 + return false + end + if m1 && (pa1.data[i] != pa2.data[i]) + return false + end + end + return true +end + _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 @@ -55,11 +94,13 @@ function _partial_array_dim_size(min_dim) end function _resize_partialarray(iarr::PartialArray, inds) - min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + min_sizes = ntuple( + i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds) + ) new_sizes = map(_partial_array_dim_size, min_sizes) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_sizes) new_mask = fill(false, new_sizes) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. @@ -76,7 +117,7 @@ end # The below implements the same functionality as above, but more performantly for 1D arrays. function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} # Resize arrays to accommodate new indices. - old_size = size(iarr.data, 1) + old_size = _internal_size(iarr, 1) min_size = max(old_size, _length_needed(ind)) new_size = _partial_array_dim_size(min_size) resize!(iarr.data, new_size) @@ -85,14 +126,19 @@ function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,E return iarr end -function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) - if _has_colon(optic) +function BangBang.setindex!!(pa::PartialArray, value, optic::IndexLens) + return BangBang.setindex!!(pa, value, optic.indices...) +end +Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indices...) 
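The mask array is what separates entries that have actually been assigned from spare capacity in the backing array. Below is a minimal sketch of how that surfaces at the `VarNamedTuple` level; it is not part of the patch itself and simply mirrors the assertions in `test/varnamedtuple.jl`:

```julia
using DynamicPPL: VarNamedTuple, @varname
using BangBang: setindex!!

vnt = VarNamedTuple()
vnt = setindex!!(vnt, [1.0, 2.0], @varname(x[1:2]))

haskey(vnt, @varname(x[2]))      # true: index 2 is covered by the mask
haskey(vnt, @varname(x[3]))      # false, even if the backing array already has capacity
# getindex(vnt, @varname(x[3]))  # throws BoundsError, because the mask is not set there
```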
+Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) + +function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) + if _has_colon(inds) # TODO(mhauru) This could be implemented by getting size information from `value`. # However, the corresponding getindex is more fundamentally ill-defined. throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices - if length(inds) != ndims(iarr.data) + if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end iarr = if checkbounds(Bool, iarr.mask, inds...) @@ -101,7 +147,7 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) _resize_partialarray(iarr, inds) end new_data = setindex!!(iarr.data, value, inds...) - if _is_multiindex(optic) + if _is_multiindex(inds) iarr.mask[inds...] .= true else iarr.mask[inds...] = true @@ -109,29 +155,105 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) return PartialArray(new_data, iarr.mask, iarr.make_leaf) end -function Base.getindex(iarr::PartialArray, optic::IndexLens) - if _has_colon(optic) +function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) + if _has_colon(inds) throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices - if length(inds) != ndims(iarr.data) + if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end - if !haskey(iarr, optic) + if !haskey(iarr, inds) throw(BoundsError(iarr, inds)) end return getindex(iarr.data, inds...) end -function Base.haskey(iarr::PartialArray, optic::IndexLens) - if _has_colon(optic) +function Base.haskey(iarr::PartialArray, inds) + if _has_colon(inds) throw(ArgumentError("Indexing with colons is not supported")) end - inds = optic.indices return checkbounds(Bool, iarr.mask, inds...) && all(@inbounds(getindex(iarr.mask, inds...))) end +Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) +_merge_recursive(_, x2) = x2 + +function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex) + m1 = x1.mask[ind] + m2 = x2.mask[ind] + return if m1 && m2 + _merge_recursive(x1.data[ind], x2.data[ind]) + elseif m2 + x2.data[ind] + else + x1.data[ind] + end +end + +# TODO(mhauru) Would this benefit from a specialised method for 1D PartialArrays? +function _merge_recursive(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + throw( + ArgumentError("Cannot merge PartialArrays with different number of dimensions") + ) + end + if pa1.make_leaf !== pa2.make_leaf + throw( + ArgumentError("Cannot merge PartialArrays with different make_leaf functions") + ) + end + num_dims = ndims(pa1) + merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) + result = if merge_size == _internal_size(pa2) + # Either pa2 is strictly bigger than pa1, or they are equal in size. + result = copy(pa2) + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + else + if merge_size == _internal_size(pa1) + # pa1 is bigger than pa2 + result = copy(pa1) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result = setindex!!( + result, _merge_element_recursive(result, pa2, i), Tuple(i)... + ) + end + end + result + else + # Neither is strictly bigger than the other. 
+ et = promote_type(eltype(pa1), eltype(pa2)) + new_data = Array{et,num_dims}(undef, merge_size) + new_mask = fill(false, merge_size) + result = PartialArray(new_data, new_mask, pa2.make_leaf) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result.mask[i] = true + result.data[i] = pa2.data[i] + end + end + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + end + end + return result +end + function make_leaf_array(value, ::PropertyLens{S}) where {S} return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) end @@ -146,7 +268,7 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T} num_inds = length(inds) # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(optic) ? eltype(value) : typeof(value) + et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) iarr = PartialArray(et, num_inds, make_leaf_array) return setindex!!(iarr, value, optic) end @@ -307,4 +429,19 @@ function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} end end +# TODO(mhauru) Check the performance of this, and make it into a generated function if +# necessary. +function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + result_data = vnt1.data + for k in keys(vnt2.data) + val = if haskey(result_data, k) + _merge_recursive(result_data[k], vnt2.data[k]) + else + vnt2.data[k] + end + Accessors.@reset result_data[k] = val + end + return VarNamedTuple(result_data, vnt2.make_leaf) +end + end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 85b824ffc..f9864e7be 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -86,4 +86,124 @@ using BangBang: setindex!! @test !haskey(vnt, @varname(m[1])) end +@testset "equality" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + @test vnt1 != vnt2 + + vnt2 = setindex!!(vnt2, 1.0, @varname(a)) + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + + # Try with index lenses too + vnt1 = setindex!!(vnt1, 2, @varname(c[2])) + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, 3, @varname(c[2])) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + + vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) + vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) + @test vnt1 != vnt2 +end + +@testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
+ @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + @test merge(vnt1, vnt2) == expected_merge + + # PartialArrays with different sizes. 
+ vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[1025])) + vnt2 = setindex!!(vnt2, 2, @varname(a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(a[2])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) + @test merge(vnt2, vnt1) == expected_merge_21 + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) + @test merge(vnt2, vnt1) == expected_merge_21 +end + end From 15d5a8a97795de35390706e858cf60a48cb17b76 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 12:09:39 +0000 Subject: [PATCH 12/45] Start using VNT in FastLDF --- src/DynamicPPL.jl | 2 ++ src/contexts/init.jl | 16 +++------ src/fasteval.jl | 81 +++++++++++++------------------------------ test/fasteval.jl | 8 ++--- test/varnamedtuple.jl | 12 +++---- 5 files changed, 39 insertions(+), 80 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..5f32a8b66 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -178,6 +178,8 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuple include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..a0ad92fe3 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -215,8 +215,7 @@ end """ VectorWithRanges( - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, vect::AbstractVector{<:Real}, ) @@ -231,20 +230,13 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
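As a hypothetical illustration of how these ranges are meant to be consumed (not part of the patch, and assuming the non-exported names are imported explicitly), a `VectorWithRanges` built by hand might look like this:

```julia
using DynamicPPL: VarNamedTuple, RangeAndLinked, VectorWithRanges, InitFromParams, @varname
using BangBang: setindex!!

# x occupies the first two entries of the flat vector (unlinked), y the third (linked).
ranges = VarNamedTuple()
ranges = setindex!!(ranges, RangeAndLinked(1:2, false), @varname(x))
ranges = setindex!!(ranges, RangeAndLinked(3:3, true), @varname(y))
vwr = VectorWithRanges(ranges, [0.1, 0.2, -1.5])

# This is the object that InitFromParams reads parameter values from during evaluation.
strategy = InitFromParams(vwr, nothing)
```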
""" -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} - # This NamedTuple stores the ranges for identity VarNames - iden_varname_ranges::N - # This Dict stores the ranges for all other VarNames - varname_ranges::Dict{VarName,RangeAndLinked} +struct VectorWithRanges{VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} + # Ranges for all VarNames + varname_ranges::VNT # The full parameter vector which we index into to get variable values vect::T end -function _get_range_and_linked( - vr::VectorWithRanges, ::VarName{sym,typeof(identity)} -) where {sym} - return vr.iden_varname_ranges[sym] -end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) return vr.varname_ranges[vn] end diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..b82180dca 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -13,6 +13,7 @@ using DynamicPPL: RangeAndLinked, VectorWithRanges, Metadata, + VarNamedTuple, VarNamedVector, default_accumulators, float_type_with_fallback, @@ -140,14 +141,13 @@ struct FastLDF{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, - N<:NamedTuple, + VNT<:VarNamedTuple, ADP<:Union{Nothing,DI.GradientPrep}, } model::M adtype::AD _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} + _varname_ranges::VNT _adprep::ADP _dim::Int @@ -159,7 +159,7 @@ struct FastLDF{ ) # Figure out which variable corresponds to which index, and # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + all_ranges = get_ranges_and_linked(varinfo) x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -169,19 +169,17 @@ struct FastLDF{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, + FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x ) end return new{ typeof(model), typeof(adtype), typeof(getlogdensity), - typeof(all_iden_ranges), + typeof(all_ranges), typeof(prep), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim + model, adtype, getlogdensity, all_ranges, prep, dim ) end end @@ -206,18 +204,15 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct FastLogDensityAt{M<:Model,F<:Function,VNT<:VarNamedTuple} model::M getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} + varname_ranges::VNT end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = InitContext( Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), + InitFromParams(VectorWithRanges(f.varname_ranges, params), nothing), ) model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) @@ -242,20 +237,14 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - )( - params - ) + return FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges)(params) end function LogDensityProblems.logdensity_and_gradient( fldf::FastLDF, params::AbstractVector{<:Real} ) return 
DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - ), + FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges), fldf._adprep, fldf.adtype, params, @@ -291,62 +280,42 @@ end Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter representation, along with whether each variable is linked or unlinked. -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +This function returns a VarNamedTuple mapping all VarNames to their corresponding +`RangeAndLinked`. """ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) + this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_ranges = merge(all_ranges, this_md_others) end - return all_iden_ranges, all_ranges + return all_ranges end function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others + all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_ranges end function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset end function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset end diff --git a/test/fasteval.jl b/test/fasteval.jl index db2333711..2ad50ed26 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -36,17 +36,13 @@ end else unlinked_vi end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using 
the ranges # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity - nt_ranges[AbstractPPL.getsym(vn)] - else - dict_ranges[vn] - end + range_with_linked = ranges[vn] @test params[range_with_linked.range] == DynamicPPL.getindex_internal(vi, vn) # Check that the link status is correct diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index f9864e7be..99f528175 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -180,11 +180,11 @@ end vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() vnt1 = setindex!!(vnt1, 1, @varname(a[1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[1025])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257])) vnt2 = setindex!!(vnt2, 2, @varname(a[1])) vnt2 = setindex!!(vnt2, 2, @varname(a[2])) expected_merge_12 = VarNamedTuple() - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) @test merge(vnt1, vnt2) == expected_merge_12 @@ -194,13 +194,13 @@ end vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) expected_merge_12 = VarNamedTuple() expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1])) - expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) @test merge(vnt1, vnt2) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) @test merge(vnt2, vnt1) == expected_merge_21 From 871eb9fd1216f392460462d4c84d8a38ca89da05 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 12:41:55 +0000 Subject: [PATCH 13/45] Move _compose_no_identity to utils.jl --- src/utils.jl | 16 ++++++++++++++++ src/varnamedvector.jl | 16 ---------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 75fb805dc..fe2879182 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -949,3 +949,19 @@ end Return `typeof(x)` stripped of its type parameters. """ basetypeof(x::T) where {T} = Base.typename(T).wrapper + +# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if +# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only +# the latter one would be kept. +""" + _compose_no_identity(f, g) + +Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. + +This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type +conflicts. 
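A quick sketch of the difference this makes (hypothetical REPL lines, not part of the patch):

```julia
using DynamicPPL: _compose_no_identity
using Accessors: IndexLens

ix = IndexLens((1,))
_compose_no_identity(ix, identity) === ix      # identity is dropped entirely
_compose_no_identity(identity, identity) === identity
(ix ∘ identity) isa Base.ComposedFunction      # what plain ∘ would produce instead
```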
+""" +_compose_no_identity(f, g) = f ∘ g +_compose_no_identity(::typeof(identity), g) = g +_compose_no_identity(f, ::typeof(identity)) = f +_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..e5d2f2c2e 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1355,22 +1355,6 @@ function nextrange(vnv::VarNamedVector, x) return (offset + 1):(offset + length(x)) end -# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if -# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only -# the latter one would be kept. -""" - _compose_no_identity(f, g) - -Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. - -This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type -conflicts. -""" -_compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity - """ shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) From 4a1156038eb673dc9567d8a3a4d008455ec83908 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 22 Nov 2025 00:26:16 +0000 Subject: [PATCH 14/45] Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129) * Implement `ParamsWithStats` for `FastLDF` * Add comments * Implement `bundle_samples` for ParamsWithStats -> MCMCChains * Remove redundant comment * don't need Statistics? --- ext/DynamicPPLMCMCChainsExt.jl | 37 ++++++++++++++++ src/DynamicPPL.jl | 2 +- src/chains.jl | 57 ++++++++++++++++++++++++ src/fasteval.jl | 81 ++++++++++++++++++++++++---------- test/chains.jl | 28 +++++++++++- 5 files changed, 180 insertions(+), 25 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..e74f0b8a9 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -140,6 +140,43 @@ function AbstractMCMC.to_samples( end end +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DynamicPPL.Model, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + save_state=false, + stats=missing, + sort_chain=false, + discard_initial=0, + thinning=1, + kwargs..., +) + bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1)) + + # Add additional MCMC-specific info + info = bare_chain.info + if save_state + info = merge(info, (model=model, sampler=spl, samplerstate=state)) + end + if !ismissing(stats) + info = merge(info, (start_time=stats.start, stop_time=stats.stop)) + end + + # Reconstruct the chain with the extra information + # Yeah, this is quite ugly. Blame MCMCChains. + chain = MCMCChains.Chains( + bare_chain.value.data, + names(bare_chain), + bare_chain.name_map; + info=info, + start=discard_initial + 1, + thin=thinning, + ) + return sort_chain ? 
sort(chain) : chain +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..6d3900e91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -202,6 +202,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -209,7 +210,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..892423822 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -133,3 +133,60 @@ function ParamsWithStats( end return ParamsWithStats(params, stats) end + +""" + ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided +`param_vector`. + +This method is intended to replace the old method of obtaining parameters and statistics +via `unflatten` plus re-evaluation. It is faster for two reasons: + +1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as + otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent + MCMC iterations). +2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`. +""" +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + strategy = InitFromParams( + VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + nothing, + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + _, vi = DynamicPPL.Experimental.fast_evaluate!!( + ldf.model, strategy, AccumulatorTuple(accs) + ) + params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + lp=DynamicPPL.getlogjoint(vi), + ), + ) + end + return ParamsWithStats(params, stats) +end diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..722760fa1 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -3,6 +3,7 @@ using DynamicPPL: AccumulatorTuple, InitContext, InitFromParams, + AbstractInitStrategy, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI using Random: Random +""" + DynamicPPL.Experimental.fast_evaluate!!( + [rng::Random.AbstractRNG,] + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, params::AbstractVector{<:Real} + ) + +Evaluate a model using parameters obtained via `strategy`, and only computing the results in +the provided accumulators. + +It is assumed that the accumulators passed in have been initialised to appropriate values, +as this function will not reset them. The default constructors for each accumulator will do +this for you correctly. 
+ +Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` +argument may be mutated (depending on how the accumulators are implemented); hence the `!!` +in the function name. +""" +@inline function fast_evaluate!!( + # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads + # to extra allocations (even for trivial models) and much slower runtime. + rng::Random.AbstractRNG, + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, +) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + return DynamicPPL._evaluate!!(model, vi) +end +@inline function fast_evaluate!!( + model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple +) + # This `@inline` is also mandatory for performance + return fast_evaluate!!(Random.default_rng(), model, strategy, accs) +end + """ FastLDF( model::Model, @@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = InitContext( - Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - accs = map( - acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), - accs, - ) - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - _, vi = DynamicPPL._evaluate!!(model, vi) + _, vi = fast_evaluate!!(f.model, strategy, accs) return f.getlogdensity(vi) end diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..43b877d62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,30 @@ using Test end end +@testset "ParamsWithStats from FastLDF" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + params = [x for x in vi[:]] + + # Get the ParamsWithStats using FastLDF + fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) + ps = ParamsWithStats(params, fldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module From 766f6635903c401a79d3c2427dc60225f0053dad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 11:41:51 +0000 Subject: [PATCH 15/45] Make FastLDF the default (#1139) * Make FastLDF the default * Add miscellaneous LogDensityProblems tests * Use `init!!` instead of `fast_evaluate!!` * Rename files, rebalance tests --- HISTORY.md | 23 +- docs/src/api.md | 8 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 11 +- src/DynamicPPL.jl | 4 + src/chains.jl | 8 +- src/experimental.jl | 2 - src/fasteval.jl | 387 --------------- src/logdensityfunction.jl | 579 +++++++++++------------ src/model.jl | 54 ++- test/ad.jl | 137 ------ test/chains.jl | 8 +- test/fasteval.jl | 233 --------- test/logdensityfunction.jl | 263 ++++++++-- test/runtests.jl | 7 +- 14 files changed, 584 insertions(+), 1140 deletions(-) delete mode 100644 src/fasteval.jl delete mode 100644 test/ad.jl delete mode 100644 test/fasteval.jl diff --git a/HISTORY.md b/HISTORY.md index 0f0102ce4..91306c219 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,17 @@ ### Breaking changes +#### Fast Log Density Functions + +This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + +As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. +In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. +If you were previously relying on this behaviour, you will need to store a VarInfo separately. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. 
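The "Fast Log Density Functions" entry above implies a small migration for code that used to read the VarInfo back out of a `LogDensityFunction`. A hedged sketch of what that looks like (hypothetical user code, not part of the patch):

```julia
using DynamicPPL, Distributions
import LogDensityProblems

@model function demo()
    x ~ Normal()
end
model = demo()

vi = VarInfo(model)                        # keep your own reference if you still need it
ldf = LogDensityFunction(model, getlogjoint, vi)

LogDensityProblems.logdensity(ldf, vi[:])  # evaluation works as before
ldf.model, ldf.adtype                      # the only fields that remain valid to access
# ldf.varinfo                              # no longer available
```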
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). -### Other changes - -#### FastLDF - -Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. -Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. - -Please note that `FastLDF` is currently considered internal and its API may change without warning. -We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. - -For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. - ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/docs/src/api.md b/docs/src/api.md index e81f18dc7..adb476db5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished using [`init!!`](@ref) on: + +```@docs +OnlyAccsVarInfo +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -510,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo. It is really a thin wrapper around using `evaluate!!` with an `InitContext`. ```@docs -DynamicPPL.init!! +init!! ``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... 
) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6d3900e91..a885f6a96 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,8 +92,12 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, + # evaluation + evaluate!!, + init!!, # LogDensityFunction LogDensityFunction, + OnlyAccsVarInfo, # Leaf contexts AbstractContext, contextualize, diff --git a/src/chains.jl b/src/chains.jl index 892423822..f176b8e68 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -137,7 +137,7 @@ end """ ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -174,9 +174,7 @@ function ParamsWithStats( else (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) end - _, vi = DynamicPPL.Experimental.fast_evaluate!!( - ldf.model, strategy, AccumulatorTuple(accs) - ) + _, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy) params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values if include_log_probs stats = merge( diff --git a/src/experimental.jl b/src/experimental.jl index c644c09b2..8c82dca68 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,8 +2,6 @@ module Experimental using DynamicPPL: DynamicPPL -include("fasteval.jl") - # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) diff --git a/src/fasteval.jl b/src/fasteval.jl deleted file mode 100644 index 722760fa1..000000000 --- a/src/fasteval.jl +++ /dev/null @@ -1,387 +0,0 @@ -using DynamicPPL: - AbstractVarInfo, - AccumulatorTuple, - InitContext, - InitFromParams, - AbstractInitStrategy, - LogJacobianAccumulator, - LogLikelihoodAccumulator, - LogPriorAccumulator, - Model, - ThreadSafeVarInfo, - VarInfo, - OnlyAccsVarInfo, - RangeAndLinked, - VectorWithRanges, - Metadata, - VarNamedVector, - default_accumulators, - float_type_with_fallback, - getlogjoint, - getlogjoint_internal, - getloglikelihood, - getlogprior, - getlogprior_internal -using ADTypes: ADTypes -using BangBang: BangBang -using AbstractPPL: AbstractPPL, VarName -using LogDensityProblems: LogDensityProblems -import DifferentiationInterface as DI -using Random: Random - -""" - DynamicPPL.Experimental.fast_evaluate!!( - [rng::Random.AbstractRNG,] - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, params::AbstractVector{<:Real} - ) - -Evaluate a model using parameters obtained via `strategy`, and only computing the results in -the provided accumulators. 
- -It is assumed that the accumulators passed in have been initialised to appropriate values, -as this function will not reset them. The default constructors for each accumulator will do -this for you correctly. - -Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` -argument may be mutated (depending on how the accumulators are implemented); hence the `!!` -in the function name. -""" -@inline function fast_evaluate!!( - # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads - # to extra allocations (even for trivial models) and much slower runtime. - rng::Random.AbstractRNG, - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, -) - ctx = InitContext(rng, strategy) - model = DynamicPPL.setleafcontext(model, ctx) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - param_eltype = DynamicPPL.get_param_eltype(strategy) - accs = map(accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) - end - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - return DynamicPPL._evaluate!!(model, vi) -end -@inline function fast_evaluate!!( - model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple -) - # This `@inline` is also mandatory for performance - return fast_evaluate!!(Random.default_rng(), model, strategy, accs) -end - -""" - FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at that point. - -This information can be extracted using the LogDensityProblems.jl interface, specifically, -using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If -`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD -backend type, then `logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term - for any variables that have been linked in the provided VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of - linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of - linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, - since transforms are only applied to random variables) - -!!! 
note - By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of - `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created - with a linked or unlinked VarInfo. This is done primarily to ease interoperability with - MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created for you. If -you provide a different function, you have to manually create a VarInfo and pass it as the -third argument. - -If the `adtype` keyword argument is provided, then this struct will also store the adtype -along with other information for efficient calculation of the gradient of the log density. -Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend -itself to have been loaded (e.g. with `import Backend`). - -## Fields - -Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: - -- `fastldf.model`: The original model from which this `FastLDF` was constructed. -- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD - type was provided. - -# Extended help - -Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a -given set of parameters: - -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. - -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. - -In general, both of these approaches work fine, but the fact that they modify the VarInfo's -metadata can often be quite wasteful. In particular, it is very common that the only outputs -we care about from model evaluation are those which are stored in accumulators, such as log -probability densities, or `ValuesAsInModel`. - -To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains -accumulators. It implements enough of the `AbstractVarInfo` interface to not error during -model evaluation. - -Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with -it, it is mandatory that parameters are provided from outside the VarInfo, namely via -`InitContext`. - -The main problem that we face is that it is not possible to directly implement -`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. -In particular, it is not clear: - - - which parts of the vector correspond to which random variables, and - - whether the variables are linked or unlinked. - -Traditionally, this problem has been solved by `unflatten`, because that function would -place values into the VarInfo's metadata alongside the information about ranges and linking. -That way, when we evaluate with `DefaultContext`, we can read this information out again. -However, we want to avoid using a metadata. Thus, here, we _extract this information from -the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we -store a mapping from VarNames to ranges in that vector, along with link status. - -For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all -other VarNames, this is stored in a Dict. The internal data structure used to represent this -could almost certainly be optimised further. See e.g. the discussion in -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
- -When evaluating the model, this allows us to combine the parameter vector together with those -ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read -parameter values from the vector. - -Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. -""" -struct FastLDF{ - M<:Model, - AD<:Union{ADTypes.AbstractADType,Nothing}, - F<:Function, - N<:NamedTuple, - ADP<:Union{Nothing,DI.GradientPrep}, -} - model::M - adtype::AD - _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} - _adprep::ADP - _dim::Int - - function FastLDF( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - # Figure out which variable corresponds to which index, and - # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) - x = [val for val in varinfo[:]] - dim = length(x) - # Do AD prep if needed - prep = if adtype === nothing - nothing - else - # Make backend-specific tweaks to the adtype - adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, - ) - end - return new{ - typeof(model), - typeof(adtype), - typeof(getlogdensity), - typeof(all_iden_ranges), - typeof(prep), - }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim - ) - end -end - -################################### -# LogDensityProblems.jl interface # -################################### -""" - fast_ldf_accs(getlogdensity::Function) - -Determine which accumulators are needed for fast evaluation with the given -`getlogdensity` function. 
-""" -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) - return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) -end -function fast_ldf_accs(::typeof(getlogprior_internal)) - return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) -end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) - -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} - model::M - getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} -end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ) - accs = fast_ldf_accs(f.getlogdensity) - _, vi = fast_evaluate!!(f.model, strategy, accs) - return f.getlogdensity(vi) -end - -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - )( - params - ) -end - -function LogDensityProblems.logdensity_and_gradient( - fldf::FastLDF, params::AbstractVector{<:Real} -) - return DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges - ), - fldf._adprep, - fldf.adtype, - params, - ) -end - -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} -) where {M} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} -) where {M} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.dimension(fldf::FastLDF) - return fldf._dim -end - -###################################################### -# Helper functions to extract ranges and link status # -###################################################### - -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. -""" - get_ranges_and_linked(varinfo::VarInfo) - -Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter -representation, along with whether each variable is linked or unlinked. - -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
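A minimal usage sketch (internal helper; `model`, `x` and `y[1]` are assumed to be defined
as in the illustration further above, and the field accesses mirror those used in the tests):

```julia
vi = VarInfo(model)
iden_nt, others = get_ranges_and_linked(vi)
iden_nt.x.range                    # range of `x` within `vi[:]`
others[@varname(y[1])].is_linked   # whether `y[1]` was linked in `vi`
```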
-""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) - end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7c7438c9f..65eab448e 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,312 +1,263 @@ -using AbstractMCMC: AbstractModel +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + AbstractInitStrategy, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI +using Random: Random """ - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. - -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. 
-""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) A struct which contains a model, along with all the information necessary to: - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) !!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. 
If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: + +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. 
In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext`. + +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the +LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with +link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle +models which have variable numbers of parameters, or models which may visit random variables +in different orders depending on stochastic control flow. **Indeed, silent errors may occur +with such models.** This is a general limitation of vectorised parameters: the original +`unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} + _adprep::ADP + _dim::Int function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - if adtype === nothing - prep = nothing + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + x = [val for val in varinfo[:]] + dim = length(x) + # Do AD prep if needed + prep = if adtype === nothing + nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) + DI.prepare_gradient( + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, + ) end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep + return new{ + typeof(model), + typeof(adtype), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(prep), + }( + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim ) end end +################################### +# LogDensityProblems.jl interface # +################################### """ - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) + fast_ldf_accs(getlogdensity::Function) -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. """ -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. 
Please specify a VarInfo explicitly. - """ - return error(msg) +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, +function (f::LogDensityAt)(params::AbstractVector{<:Real}) + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) + accs = fast_ldf_accs(f.getlogdensity) + _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) + return f.getlogdensity(vi) +end -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} ) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + )( + params + ) end -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V +function LogDensityProblems.logdensity_and_gradient( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ), + ldf._adprep, + ldf.adtype, + params, ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. 
-""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) end -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} +) where {M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils """ tweak_adtype( @@ -325,53 +276,77 @@ By default, this just returns the input unchanged. """ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. +###################################################### +# Helper functions to extract ranges and link status # +###################################################### -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. 
The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. """ -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false + get_ranges_and_linked(varinfo::VarInfo) -""" - getmodel(f) +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model +This function should return a tuple containing: +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. """ - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. 
-""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/src/model.jl b/src/model.jl index 2bcfe8f98..9029318b1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -881,30 +881,56 @@ end [init_strategy::AbstractInitStrategy=InitFromPrior()] ) -Evaluate the `model` and replace the values of the model's random variables -in the given `varinfo` with new values, using a specified initialisation strategy. -If the values in `varinfo` are not set, they will be added -using a specified initialisation strategy. +Evaluate the `model` and replace the values of the model's random variables in the given +`varinfo` with new values, using a specified initialisation strategy. If the values in +`varinfo` are not set, they will be added using a specified initialisation strategy. If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function init!!( +@inline function init!!( + # Note that this `@inline` is mandatory for performance, especially for + # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for + # trivial models) and much slower runtime. rng::Random.AbstractRNG, model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), + vi::AbstractVarInfo, + strategy::AbstractInitStrategy=InitFromPrior(), ) - new_model = setleafcontext(model, InitContext(rng, init_strategy)) - return evaluate!!(new_model, varinfo) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + return if Threads.nthreads() > 1 + # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that + # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` + # won't have been filled with parameters prior to `init!!` being called. + # + # Note that this eltype promotion is only needed for threadsafe evaluation. In an + # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a + # similar method. In other words, it should not be here, and it should not be inside + # `unflatten` either. The problem is performance. Shifting this code around can have + # massive, inexplicable, impacts on performance. This should be investigated + # properly. 
+ param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(vi.accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + vi = DynamicPPL.setaccs!!(vi, accs) + tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + else + return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + end end -function init!!( - model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), +@inline function init!!( + model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - return init!!(Random.default_rng(), model, varinfo, init_strategy) + # This `@inline` is also mandatory for performance + return init!!(Random.default_rng(), model, vi, strategy) end """ diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index 43b877d62..12a9ece71 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -66,7 +66,7 @@ using Test end end -@testset "ParamsWithStats from FastLDF" begin +@testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) @@ -77,9 +77,9 @@ end end params = [x for x in vi[:]] - # Get 
the ParamsWithStats using FastLDF - fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) - ps = ParamsWithStats(params, fldf) + # Get the ParamsWithStats using LogDensityFunction + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) + ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected @test length(ps.params) == length(keys(vi)) diff --git a/test/fasteval.jl b/test/fasteval.jl deleted file mode 100644 index db2333711..000000000 --- a/test/fasteval.jl +++ /dev/null @@ -1,233 +0,0 @@ -module DynamicPPLFastLDFTests - -using AbstractPPL: AbstractPPL -using Chairmarks -using DynamicPPL -using Distributions -using DistributionsAD: filldist -using ADTypes -using DynamicPPL.Experimental: FastLDF -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest -using LinearAlgebra: I -using Test -using LogDensityProblems: LogDensityProblems - -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -# Need to include this block here in case we run this test file standalone -@static if VERSION < v"1.12" - using Pkg - Pkg.add("Mooncake") - using Mooncake: Mooncake -end - -@testset "FastLDF: Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$varinfo_func" for varinfo_func in [ - DynamicPPL.untyped_varinfo, - DynamicPPL.typed_varinfo, - DynamicPPL.untyped_vector_varinfo, - DynamicPPL.typed_vector_varinfo, - ] - unlinked_vi = varinfo_func(m) - @testset "$islinked" for islinked in (false, true) - vi = if islinked - DynamicPPL.link!!(unlinked_vi, m) - else - unlinked_vi - end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = [x for x in vi[:]] - # Iterate over all variables - for vn in keys(vi) - # Check that `getindex_internal` returns the same thing as using the ranges - # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity - nt_ranges[AbstractPPL.getsym(vn)] - else - dict_ranges[vn] - end - @test params[range_with_linked.range] == - DynamicPPL.getindex_internal(vi, vn) - # Check that the link status is correct - @test range_with_linked.is_linked == islinked - end - - # Compare results of FastLDF vs ordinary LogDensityFunction. These tests - # can eventually go once we replace LogDensityFunction with FastLDF, but - # for now it helps to have this check! (Eventually we should just check - # against manually computed log-densities). - # - # TODO(penelopeysm): I think we need to add tests for some really - # pathological models here. - @testset "$getlogdensity" for getlogdensity in ( - DynamicPPL.getlogjoint_internal, - DynamicPPL.getlogjoint, - DynamicPPL.getloglikelihood, - DynamicPPL.getlogprior_internal, - DynamicPPL.getlogprior, - ) - ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) - fldf = FastLDF(m, getlogdensity, vi) - @test LogDensityProblems.logdensity(ldf, params) ≈ - LogDensityProblems.logdensity(fldf, params) - end - end - end - end - - @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end - end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) - end - end -end - -@testset "FastLDF: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) - end - end -end - -@testset "AD with FastLDF" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = @static if VERSION < v"1.12" - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) - linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = [p for p in linked_varinfo[:]] - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $adtype" - - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. 
- @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end - -end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index fbd868f71..06492d6e1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,49 +1,240 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model +module DynamicPPLLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake + +@testset "LogDensityFunction: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if 
islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end end end -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) +@testset "LogDensityFunction: interface" begin + # miscellaneous parts of the LogDensityProblems interface + @testset "dimensions" begin + @model function m1() + x ~ Normal() + y ~ Normal() + return nothing + end + model = m1() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 2 + + @model function m2() + x ~ Dirichlet(ones(4)) + y ~ Categorical(x) + return nothing end + model = m2() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 5 + linked_vi = DynamicPPL.link!!(VarInfo(model), model) + ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi) + @test LogDensityProblems.dimension(ldf) == 4 end @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @model f() = x ~ Normal() + model = f() + # No adtype ldf = DynamicPPL.LogDensityFunction(model) @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + # With adtype + ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{1}() end end + +@testset "LogDensityFunction: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(ldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with LogDensityFunction" begin + # Used as the ground truth that others are compared against. 
+ ref_adtype = AutoForwardDiff() + + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 1474b426a..9649aebbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using ForwardDiff using LogDensityProblems using MacroTools using MCMCChains +using Mooncake using StableRNGs using ReverseDiff using Mooncake @@ -57,7 +58,6 @@ include("test_util.jl") include("simple_varinfo.jl") 
include("model.jl") include("distribution_wrappers.jl") - include("logdensityfunction.jl") include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -68,10 +68,11 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("chains.jl") - include("bijector.jl") end if GROUP == "All" || GROUP == "Group2" + include("bijector.jl") + include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") @@ -80,8 +81,6 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") - include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." From c1b935b5356ca98c040cc271e5ad8414d9fd2f61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:18:10 +0000 Subject: [PATCH 16/45] Minor refactor --- src/varnamedtuple.jl | 28 +-- test/varnamedtuple.jl | 402 +++++++++++++++++++++--------------------- 2 files changed, 216 insertions(+), 214 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 006e8f0d5..8d35b34a7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -93,15 +93,17 @@ function _partial_array_dim_size(min_dim) return factor^(Int(ceil(log(factor, min_dim)))) end +function _min_size(iarr::PartialArray, inds) + return ntuple(i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds)) +end + function _resize_partialarray(iarr::PartialArray, inds) - min_sizes = ntuple( - i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds) - ) - new_sizes = map(_partial_array_dim_size, min_sizes) + min_size = _min_size(iarr, inds) + new_size = map(_partial_array_dim_size, min_size) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_sizes) - new_mask = fill(false, new_sizes) + new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_size) + new_mask = fill(false, new_size) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. for i in CartesianIndices(iarr.data) @@ -133,14 +135,12 @@ Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indi Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) + if length(inds) != ndims(iarr) + throw(BoundsError(iarr, inds)) + end if _has_colon(inds) - # TODO(mhauru) This could be implemented by getting size information from `value`. - # However, the corresponding getindex is more fundamentally ill-defined. throw(ArgumentError("Indexing with colons is not supported")) end - if length(inds) != ndims(iarr) - throw(ArgumentError("Invalid index $(inds)")) - end iarr = if checkbounds(Bool, iarr.mask, inds...) 
iarr else @@ -156,12 +156,12 @@ function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES end function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end if length(inds) != ndims(iarr) throw(ArgumentError("Invalid index $(inds)")) end + if _has_colon(inds) + throw(ArgumentError("Indexing with colons is not supported")) + end if !haskey(iarr, inds) throw(BoundsError(iarr, inds)) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 99f528175..02ed3bca8 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -4,206 +4,208 @@ using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: @varname, VarNamedTuple using BangBang: setindex!! -@testset "Basic sets and gets" begin - vnt = VarNamedTuple() - vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) - @test @inferred(getindex(vnt, @varname(a))) == 32.0 - - vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) - @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] - @test @inferred(getindex(vnt, @varname(b[2]))) == 2 - - vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) - @test @inferred(getindex(vnt, @varname(a))) == 64.0 - @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] - - vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) - @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] - @test @inferred(getindex(vnt, @varname(b[2]))) == 15 - - vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) - @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] - - vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) - @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] - @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 - - vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) - @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 - - vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) - @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 - - # These can't be @inferred because `d` now has an abstract element type. Note that this - # does not ruin type stability for other varnames that don't involve `d`. - vnt = setindex!!(vnt, "a", @varname(d[5])) - @test getindex(vnt, @varname(d[5])) == "a" - - vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) - @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 - - vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) - @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 - - vec = fill(1.0, 4) - vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) - @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec - @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] - @test haskey(vnt, @varname(j[4])) - @test !haskey(vnt, @varname(j[5])) - @test_throws BoundsError getindex(vnt, @varname(j[5])) - - vec = fill(2.0, 4) - vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) - @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 - @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec - @test haskey(vnt, @varname(j[5])) - - arr = fill(2.0, (4, 2)) - vn = @varname(k.l[2:5, 3, 1:2, 2]) - vnt = @inferred(setindex!!(vnt, arr, vn)) - @test @inferred(getindex(vnt, vn)) == arr - # A subset of the elements set just now. - @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) - - # Not enough, or too many, indices. 
- @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) - @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) - - arr = fill(3.0, (3, 3)) - vn = @varname(k.l[1, 1:3, 1:3, 1]) - vnt = @inferred(setindex!!(vnt, arr, vn)) - @test @inferred(getindex(vnt, vn)) == arr - # A subset of the elements set just now. - @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) - # A subset of the elements set previously. - @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) - @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) - - vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) - vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) - @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] - @test !haskey(vnt, @varname(m[1])) -end - -@testset "equality" begin - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, 1.0, @varname(a)) - @test vnt1 != vnt2 - - vnt2 = setindex!!(vnt2, 1.0, @varname(a)) - @test vnt1 == vnt2 - - vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) - - # Try with index lenses too - vnt1 = setindex!!(vnt1, 2, @varname(c[2])) - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, 3, @varname(c[2])) - @test vnt1 != vnt2 - vnt2 = setindex!!(vnt2, 2, @varname(c[2])) - - vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) - vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) - @test vnt1 == vnt2 - - vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) - @test vnt1 != vnt2 -end - -@testset "merge" begin - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - expected_merge = VarNamedTuple() - # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
- @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, 1.0, @varname(a)) - vnt2 = setindex!!(vnt2, 2.0, @varname(b)) - vnt1 = setindex!!(vnt1, 1, @varname(c)) - vnt2 = setindex!!(vnt2, 2, @varname(c)) - expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) - expected_merge = setindex!!(expected_merge, 2, @varname(c)) - expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - expected_merge = VarNamedTuple() - vnt1 = setindex!!(vnt1, [1], @varname(d.a)) - vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) - vnt1 = setindex!!(vnt1, [1], @varname(d.c)) - vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) - expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) - expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) - expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) - vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) - expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) - expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) - vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) - vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) - expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) - vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) - expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) - expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) - @test merge(vnt1, vnt2) == expected_merge - - vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) - vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) - expected_merge = setindex!!( - expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) - ) - vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - @test merge(vnt1, vnt2) == expected_merge - - # PartialArrays with different sizes. 
- vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - vnt1 = setindex!!(vnt1, 1, @varname(a[1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[257])) - vnt2 = setindex!!(vnt2, 2, @varname(a[1])) - vnt2 = setindex!!(vnt2, 2, @varname(a[2])) - expected_merge_12 = VarNamedTuple() - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) - expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) - expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) - @test merge(vnt1, vnt2) == expected_merge_12 - expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) - @test merge(vnt2, vnt1) == expected_merge_21 - - vnt1 = VarNamedTuple() - vnt2 = VarNamedTuple() - vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) - vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) - expected_merge_12 = VarNamedTuple() - expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) - expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) - expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) - @test merge(vnt1, vnt2) == expected_merge_12 - expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) - @test merge(vnt2, vnt1) == expected_merge_21 +@testset "VarNamedTuple" begin + @testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) + end + + @testset "equality" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + @test vnt1 != vnt2 + + vnt2 = setindex!!(vnt2, 1.0, @varname(a)) + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + + # Try with index lenses too + vnt1 = setindex!!(vnt1, 2, @varname(c[2])) + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, 3, @varname(c[2])) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + + vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) + vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) + @test vnt1 != vnt2 + end + + @testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. 
+ @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test merge(vnt1, vnt2) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + @test merge(vnt1, vnt2) == expected_merge + + # PartialArrays with different sizes. 
+ vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257])) + vnt2 = setindex!!(vnt2, 2, @varname(a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(a[2])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) + @test merge(vnt2, vnt1) == expected_merge_21 + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) + @test merge(vnt2, vnt1) == expected_merge_21 + end end end From 262a6f98c05f2536ef2e9dcdf6910f332e292e36 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:21:42 +0000 Subject: [PATCH 17/45] Remove IndexDict --- src/varnamedtuple.jl | 36 ++---------------------------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 8d35b34a7..d07555db2 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -30,11 +30,6 @@ function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data end -struct IndexDict{T<:Function,Keys,Values} - data::Dict{Keys,Values} - make_leaf::T -end - struct PartialArray{T<:Function,ElType,numdims} data::Array{ElType,numdims} mask::Array{Bool,numdims} @@ -273,18 +268,6 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T} return setindex!!(iarr, value, optic) end -function make_leaf_dict(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) -end -make_leaf_dict(value, ::typeof(identity)) = value -function make_leaf_dict(value, optic::ComposedFunction) - sub = make_leaf_dict(value, optic.outer) - return make_leaf_dict(sub, optic.inner) -end -function make_leaf_dict(value, optic::IndexLens) - return IndexDict(Dict(optic.indices => value), make_leaf_dict) -end - VarNamedTuple() = VarNamedTuple((;), make_leaf_array) function Base.show(io::IO, vnt::VarNamedTuple) @@ -299,10 +282,6 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -function Base.show(io::IO, id::IndexDict) - return print(io, id.data) -end - Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] function varname_to_lens(name::VarName{S}) where {S} @@ -312,18 +291,13 @@ end function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end -function Base.getindex( - x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction -) +function Base.getindex(x::Union{VarNamedTuple,PartialArray}, optic::ComposedFunction) subdata = getindex(x, optic.inner) return getindex(subdata, optic.outer) end function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) 
where {S} return getindex(vnt.data, S) end -function Base.getindex(id::IndexDict, optic::IndexLens) - return getindex(id.data, optic.indices) -end function Base.haskey(vnt::VarNamedTuple, name::VarName) return haskey(vnt, varname_to_lens(name)) @@ -336,9 +310,7 @@ function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) end Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) Base.haskey(::VarNamedTuple, ::IndexLens) = false -Base.haskey(::IndexDict, ::PropertyLens) = false # TODO(mhauru) This is type piracy. Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) @@ -353,7 +325,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) end function BangBang.setindex!!( - vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction + vnt::Union{VarNamedTuple,PartialArray}, value, optic::ComposedFunction ) sub = if haskey(vnt, optic.inner) BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) @@ -371,10 +343,6 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) end -function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens) - return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf) -end - function apply(func, vnt::VarNamedTuple, name::VarName) if !haskey(vnt.data, name.name) throw(KeyError(repr(name))) From abea08782d93e1d4668c1838bce95d7c8b8c7483 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 12:27:17 +0000 Subject: [PATCH 18/45] Remove make_leaf as a field --- src/varnamedtuple.jl | 59 +++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d07555db2..0dfb9ec11 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -19,28 +19,24 @@ function _is_multiindex(::T) where {T<:Tuple} return any(x <: UnitRange || x <: Colon for x in T.parameters) end -struct VarNamedTuple{T<:Function,Names,Values} +struct VarNamedTuple{Names,Values} data::NamedTuple{Names,Values} - make_leaf::T end # TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for # PartialArrays. -function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data -end +Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data -struct PartialArray{T<:Function,ElType,numdims} +struct PartialArray{ElType,numdims} data::Array{ElType,numdims} mask::Array{Bool,numdims} - make_leaf::T end -function PartialArray(eltype, num_dims, make_leaf=make_leaf_array) +function PartialArray(eltype, num_dims) dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) - return PartialArray(data, mask, make_leaf) + return PartialArray(data, mask) end Base.ndims(iarr::PartialArray) = ndims(iarr.data) @@ -50,11 +46,11 @@ Base.ndims(iarr::PartialArray) = ndims(iarr.data) _internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) 
function Base.copy(pa::PartialArray) - return PartialArray(copy(pa.data), copy(pa.mask), pa.make_leaf) + return PartialArray(copy(pa.data), copy(pa.mask)) end function Base.:(==)(pa1::PartialArray, pa2::PartialArray) - if (pa1.make_leaf !== pa2.make_leaf) || (ndims(pa1) != ndims(pa2)) + if ndims(pa1) != ndims(pa2) return false end size1 = _internal_size(pa1) @@ -108,11 +104,11 @@ function _resize_partialarray(iarr::PartialArray, inds) @inbounds new_data[i] = iarr.data[i] end end - return PartialArray(new_data, new_mask, iarr.make_leaf) + return PartialArray(new_data, new_mask) end # The below implements the same functionality as above, but more performantly for 1D arrays. -function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} +function _resize_partialarray(iarr::PartialArray{Eltype,1}, (ind,)) where {Eltype} # Resize arrays to accommodate new indices. old_size = _internal_size(iarr, 1) min_size = max(old_size, _length_needed(ind)) @@ -147,7 +143,7 @@ function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES else iarr.mask[inds...] = true end - return PartialArray(new_data, iarr.mask, iarr.make_leaf) + return PartialArray(new_data, iarr.mask) end function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) @@ -194,11 +190,6 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) ArgumentError("Cannot merge PartialArrays with different number of dimensions") ) end - if pa1.make_leaf !== pa2.make_leaf - throw( - ArgumentError("Cannot merge PartialArrays with different make_leaf functions") - ) - end num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) result = if merge_size == _internal_size(pa2) @@ -229,7 +220,7 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) et = promote_type(eltype(pa1), eltype(pa2)) new_data = Array{et,num_dims}(undef, merge_size) new_mask = fill(false, merge_size) - result = PartialArray(new_data, new_mask, pa2.make_leaf) + result = PartialArray(new_data, new_mask) for i in CartesianIndices(pa2.data) @inbounds if pa2.mask[i] result.mask[i] = true @@ -249,26 +240,26 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end -function make_leaf_array(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +function make_leaf(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,))) end -make_leaf_array(value, ::typeof(identity)) = value -function make_leaf_array(value, optic::ComposedFunction) - sub = make_leaf_array(value, optic.outer) - return make_leaf_array(sub, optic.inner) +make_leaf(value, ::typeof(identity)) = value +function make_leaf(value, optic::ComposedFunction) + sub = make_leaf(value, optic.outer) + return make_leaf(sub, optic.inner) end -function make_leaf_array(value, optic::IndexLens{T}) where {T} +function make_leaf(value, optic::IndexLens{T}) where {T} inds = optic.indices num_inds = length(inds) # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? 
eltype(value) : typeof(value) - iarr = PartialArray(et, num_inds, make_leaf_array) + iarr = PartialArray(et, num_inds) return setindex!!(iarr, value, optic) end -VarNamedTuple() = VarNamedTuple((;), make_leaf_array) +VarNamedTuple() = VarNamedTuple((;)) function Base.show(io::IO, vnt::VarNamedTuple) print(io, "(") @@ -330,17 +321,17 @@ function BangBang.setindex!!( sub = if haskey(vnt, optic.inner) BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) else - vnt.make_leaf(value, optic.outer) + make_leaf(value, optic.outer) end return BangBang.setindex!!(vnt, sub, optic.inner) end function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} # I would like this to just read - # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the # below? - return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) + return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) end function apply(func, vnt::VarNamedTuple, name::VarName) @@ -354,7 +345,7 @@ end function Base.map(func, vnt::VarNamedTuple) new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) - return VarNamedTuple(new_data, vnt.make_leaf) + return VarNamedTuple(new_data) end function Base.keys(vnt::VarNamedTuple) @@ -409,7 +400,7 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) end Accessors.@reset result_data[k] = val end - return VarNamedTuple(result_data, vnt2.make_leaf) + return VarNamedTuple(result_data) end end From 5900f6906a0b3951baf5db3f419fd838a215a10a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 14:35:18 +0000 Subject: [PATCH 19/45] Document, refactor, and fix PartialArray --- src/varnamedtuple.jl | 369 ++++++++++++++++++++++++++++-------------- test/varnamedtuple.jl | 18 ++- 2 files changed, 267 insertions(+), 120 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0dfb9ec11..79a09f678 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,42 +8,139 @@ using ..DynamicPPL: _compose_no_identity export VarNamedTuple -"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" -const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 - -const INDEX_TYPES = Union{Integer,UnitRange,Colon} - +# We define our own getindex, setindex!!, and haskey functions to be able to override their +# behaviour for some types exported from elsewhere without type piracy. This is needed +# because +# 1. We want to index into things with lenses (from Accessors.jl) using getindex and +# setindex!!. +# 2. We want to use getindex, setindex!!, and haskey as the universal functions for getting, +# setting, checking. This includes e.g. checking whether an index is valid for an Array, +# which would normally be done with checkbounds. +_haskey(x, key) = Base.haskey(x, key) +_getindex(x, inds...) = Base.getindex(x, inds...) +_setindex!!(x, value, inds...) = BangBang.setindex!!(x, value, inds...) +_getindex(arr::AbstractArray, optic::IndexLens) = _getindex(arr, optic.indices...) +_haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) +function _setindex!!(arr::AbstractArray, value, optic::IndexLens) + return _setindex!!(arr, value, optic.indices...) +end +_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) + +# Some utilities for checking what sort of indices we are dealing with. 
_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters)
-
 function _is_multiindex(::T) where {T<:Tuple}
     return any(x <: UnitRange || x <: Colon for x in T.parameters)
 end
 
-struct VarNamedTuple{Names,Values}
-    data::NamedTuple{Names,Values}
-end
+"""
+    _merge_recursive(x1, x2)
 
-# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for
-# PartialArrays.
-Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data
+Recursively merge two values `x1` and `x2`.
+
+Unlike `Base.merge`, this function is defined for all types, and by default returns the
+second argument. It is overridden for `PartialArray` and `VarNamedTuple`, since they are
+nested containers, and calls itself recursively on all elements that are found in both
+`x1` and `x2`.
+
+In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]`
+is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`, unless
+no specific method is defined for the type of `x` and `y`, in which case
+`_merge_recursive(x, y) === y`.
+"""
+_merge_recursive(_, x2) = x2
+
+"""The factor by which we increase the dimensions of PartialArrays when resizing them."""
+const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4
+
+"""A convenience for defining method argument type bounds."""
+const INDEX_TYPES = Union{Integer,UnitRange,Colon}
 
-struct PartialArray{ElType,numdims}
-    data::Array{ElType,numdims}
-    mask::Array{Bool,numdims}
+"""
+    PartialArray{ElType,numdims}
+
+An array-like structure that may only have some of its elements defined.
+
+A `PartialArray` is like a `Base.Array`, except not all of its elements are necessarily
+defined. That is to say, one can create an empty `PartialArray` `arr` and e.g. set
+`arr[3,2] = 5`, but asking for `arr[1,1]` may throw a `BoundsError` if `[1, 1]` has not been
+explicitly set yet.
+
+`PartialArray`s can be indexed with integer indices and ranges. Indexing is always 1-based.
+Other types of indexing allowed by `Base.Array` are not supported. Some of these are simply
+because we haven't seen a need and haven't bothered to implement them, namely boolean
+indexing, linear indexing into multidimensional arrays, and indexing with arrays. However,
+notably, indexing with colons (i.e. `:`) is not supported for more fundamental reasons.
+
+To understand this, note that a `PartialArray` has no well-defined size. For example, if one
+creates an empty array and sets `arr[3,2]`, it is unclear if that should be taken to mean
+that the array has size `(3,2)`: It could be larger, and saying that the size is `(3,2)`
+would also misleadingly suggest that all elements within `1:3,1:2` are set. This is also why
+colon indexing is ill-defined: If one would e.g. set `arr[2,:] = [1,2,3]`, we would have no
+way of saying whether the right hand side is of an acceptable size or not.
+
+The fact that its size is ill-defined also means that `PartialArray` is not a subtype of
+`AbstractArray`.
+
+All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`,
+`push!`, etc. are not defined. The element type of a `PartialArray` will change as needed
+under `setindex!!` to accommodate the new values.
+
+Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type
+`ElType` and number of dimensions `numdims`.
+
+The internal implementation of a `PartialArray` consists of two arrays: one holding the
+data and the other one being a boolean mask indicating which elements are defined. 
These +internal arrays may need resizing when new elements are set that have index ranges larger +than the current internal arrays. To avoid resizing too often, the internal arrays are +resized in exponentially increasing steps. This means that most `setindex!!` calls are very +fast, but some may incur substantial overhead due to resizing and copying data. It also +means that the largest index set so far determines the memory usage of the `PartialArray`. +`PartialArray`s are thus well-suited when most values in it will eventually be set. If only +a few scattered values are set, a structure like `SparseArray` may be more appropriate. +""" +struct PartialArray{ElType,num_dims} + data::Array{ElType,num_dims} + mask::Array{Bool,num_dims} + + function PartialArray( + data::Array{ElType,num_dims}, mask::Array{Bool,num_dims} + ) where {ElType,num_dims} + if size(data) != size(mask) + throw(ArgumentError("Data and mask arrays must have the same size")) + end + return new{ElType,num_dims}(data, mask) + end end -function PartialArray(eltype, num_dims) +""" + PartialArray{ElType,num_dims}(min_size=nothing) + +Create a new empty `PartialArray` with set element type and number of dimensions. + +The optional argument `min_size` can be used to specify the minimum initial size. This is +purely a performance optimisation, to avoid resizing if the eventual size is known ahead of +time. +""" +function PartialArray{ElType,num_dims}( + min_size::Union{Tuple,Nothing}=nothing +) where {ElType,num_dims} + if min_size === nothing + dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + else + dims = map(_partial_array_dim_size, min_size) + end dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) - data = Array{eltype,num_dims}(undef, dims) + data = Array{ElType,num_dims}(undef, dims) mask = fill(false, dims) return PartialArray(data, mask) end -Base.ndims(iarr::PartialArray) = ndims(iarr.data) +Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims +Base.eltype(::PartialArray{ElType}) where {ElType} = ElType # We deliberately don't define Base.size for PartialArray, because it is ill-defined. # The size of the .data field is an implementation detail. -_internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) +_internal_size(pa::PartialArray, args...) = size(pa.data, args...) function Base.copy(pa::PartialArray) return PartialArray(copy(pa.data), copy(pa.mask)) @@ -55,7 +152,8 @@ function Base.:(==)(pa1::PartialArray, pa2::PartialArray) end size1 = _internal_size(pa1) size2 = _internal_size(pa2) - # TODO(mhauru) This could be optimised, but not sure it's worth it. + # TODO(mhauru) This could be optimised by not calling checkbounds on all elements + # outside the size of an array, but not sure it's worth it. merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) for i in CartesianIndices(merge_size) m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false @@ -70,9 +168,20 @@ function Base.:(==)(pa1::PartialArray, pa2::PartialArray) return true end +function Base.hash(pa::PartialArray, h::UInt) + h = hash(ndims(pa), h) + for i in eachindex(pa.mask) + @inbounds if pa.mask[i] + h = hash(i, h) + h = hash(pa.data[i], h) + end + end + return h +end + +"""Return the length needed in a dimension given an index.""" _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) -_length_needed(::Colon) = 0 """Take the minimum size that a dimension of a PartialArray needs to be, and return the size we choose it to be. 
This size will be the smallest possible power of @@ -84,92 +193,100 @@ function _partial_array_dim_size(min_dim) return factor^(Int(ceil(log(factor, min_dim)))) end -function _min_size(iarr::PartialArray, inds) - return ntuple(i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds)) +"""Return the minimum internal size needed for a `PartialArray` to be able set the value +at inds. +""" +function _min_size(pa::PartialArray, inds) + return ntuple(i -> max(_internal_size(pa, i), _length_needed(inds[i])), length(inds)) end -function _resize_partialarray(iarr::PartialArray, inds) - min_size = _min_size(iarr, inds) +"""Resize a PartialArray to be able to accommodate the index inds. This operates in place +for vectors, but makes a copy for higher-dimensional arrays, unless no resizing is +necessary, in which case this is a no-op.""" +function _resize_partialarray!!(pa::PartialArray, inds) + min_size = _min_size(pa, inds) new_size = map(_partial_array_dim_size, min_size) + if new_size == _internal_size(pa) + return pa + end # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_size) + new_data = Array{eltype(pa),ndims(pa)}(undef, new_size) new_mask = fill(false, new_size) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. - for i in CartesianIndices(iarr.data) - mask_val = iarr.mask[i] - @inbounds new_mask[i] = mask_val + @inbounds for i in CartesianIndices(pa.data) + mask_val = pa.mask[i] + new_mask[i] = mask_val if mask_val - @inbounds new_data[i] = iarr.data[i] + new_data[i] = pa.data[i] end end return PartialArray(new_data, new_mask) end # The below implements the same functionality as above, but more performantly for 1D arrays. -function _resize_partialarray(iarr::PartialArray{Eltype,1}, (ind,)) where {Eltype} +function _resize_partialarray!!(pa::PartialArray{Eltype,1}, (ind,)) where {Eltype} # Resize arrays to accommodate new indices. - old_size = _internal_size(iarr, 1) + old_size = _internal_size(pa, 1) min_size = max(old_size, _length_needed(ind)) new_size = _partial_array_dim_size(min_size) - resize!(iarr.data, new_size) - resize!(iarr.mask, new_size) - @inbounds iarr.mask[(old_size + 1):new_size] .= false - return iarr + if new_size == old_size + return pa + end + resize!(pa.data, new_size) + resize!(pa.mask, new_size) + @inbounds pa.mask[(old_size + 1):new_size] .= false + return pa end -function BangBang.setindex!!(pa::PartialArray, value, optic::IndexLens) - return BangBang.setindex!!(pa, value, optic.indices...) +_getindex(pa::PartialArray, optic::IndexLens) = _getindex(pa, optic.indices...) +_haskey(pa::PartialArray, optic::IndexLens) = _haskey(pa, optic.indices) +function _setindex!!(pa::PartialArray, value, optic::IndexLens) + return _setindex!!(pa, value, optic.indices...) end -Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indices...) 
-Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices) -function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES}) - if length(inds) != ndims(iarr) - throw(BoundsError(iarr, inds)) +"""Throw an appropriate error if the given indices are invalid for `pa`.""" +function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + if length(inds) != ndims(pa) + throw(BoundsError(pa, inds)) end if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end - iarr = if checkbounds(Bool, iarr.mask, inds...) - iarr - else - _resize_partialarray(iarr, inds) - end - new_data = setindex!!(iarr.data, value, inds...) - if _is_multiindex(inds) - iarr.mask[inds...] .= true - else - iarr.mask[inds...] = true + throw(ArgumentError("Indexing PartialArrays with Colon is not supported")) end - return PartialArray(new_data, iarr.mask) + return nothing end -function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES}) - if length(inds) != ndims(iarr) - throw(ArgumentError("Invalid index $(inds)")) +function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if !_haskey(pa, inds) + throw(BoundsError(pa, inds)) end - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) - end - if !haskey(iarr, inds) - throw(BoundsError(iarr, inds)) - end - return getindex(iarr.data, inds...) + return getindex(pa.data, inds...) end -function Base.haskey(iarr::PartialArray, inds) - if _has_colon(inds) - throw(ArgumentError("Indexing with colons is not supported")) +function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + _check_index_validity(pa, inds) + return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) +end + +function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + pa = if checkbounds(Bool, pa.mask, inds...) + pa + else + _resize_partialarray!!(pa, inds) end - return checkbounds(Bool, iarr.mask, inds...) && - all(@inbounds(getindex(iarr.mask, inds...))) + new_data = setindex!!(pa.data, value, inds...) + if _is_multiindex(inds) + pa.mask[inds...] .= true + else + pa.mask[inds...] = true + end + return PartialArray(new_data, pa.mask) end Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) -Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) -_merge_recursive(_, x2) = x2 function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex) m1 = x1.mask[ind] @@ -193,7 +310,7 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) result = if merge_size == _internal_size(pa2) - # Either pa2 is strictly bigger than pa1, or they are equal in size. + # Either pa2 is strictly bigger than pa1 or they are equal in size. result = copy(pa2) for i in CartesianIndices(pa1.data) @inbounds if pa1.mask[i] @@ -240,6 +357,16 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end +struct VarNamedTuple{Names,Values} + data::NamedTuple{Names,Values} +end + +# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for +# PartialArrays. 
+Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data + +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) + function make_leaf(value, ::PropertyLens{S}) where {S} return VarNamedTuple(NamedTuple{(S,)}((value,))) end @@ -255,7 +382,7 @@ function make_leaf(value, optic::IndexLens{T}) where {T} # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray(et, num_inds) + iarr = PartialArray{et,num_inds}() return setindex!!(iarr, value, optic) end @@ -273,62 +400,35 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] function varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end -function Base.getindex(vnt::VarNamedTuple, name::VarName) - return getindex(vnt, varname_to_lens(name)) +function _getindex(vnt::VarNamedTuple, name::VarName) + return _getindex(vnt, varname_to_lens(name)) end -function Base.getindex(x::Union{VarNamedTuple,PartialArray}, optic::ComposedFunction) - subdata = getindex(x, optic.inner) - return getindex(subdata, optic.outer) -end -function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} - return getindex(vnt.data, S) -end - -function Base.haskey(vnt::VarNamedTuple, name::VarName) - return haskey(vnt, varname_to_lens(name)) +function _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return _getindex(vnt.data, S) end -Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true - -function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) - return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) +function _haskey(vnt::VarNamedTuple, name::VarName) + return _haskey(vnt, varname_to_lens(name)) end -Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(::VarNamedTuple, ::IndexLens) = false - -# TODO(mhauru) This is type piracy. -Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) +_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true -# TODO(mhauru) This is type piracy. -function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) - return BangBang.setindex!!(arr, value, optic.indices...) 
-end +_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(::VarNamedTuple, ::IndexLens) = false -function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) - return BangBang.setindex!!(vnt, value, varname_to_lens(name)) -end - -function BangBang.setindex!!( - vnt::Union{VarNamedTuple,PartialArray}, value, optic::ComposedFunction -) - sub = if haskey(vnt, optic.inner) - BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) - else - make_leaf(value, optic.outer) - end - return BangBang.setindex!!(vnt, sub, optic.inner) +function _setindex!!(vnt::VarNamedTuple, value, name::VarName) + return _setindex!!(vnt, value, varname_to_lens(name)) end -function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} +function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} # I would like this to just read - # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S)) + # return VarNamedTuple(_setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the # below? return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) @@ -338,9 +438,9 @@ function apply(func, vnt::VarNamedTuple, name::VarName) if !haskey(vnt.data, name.name) throw(KeyError(repr(name))) end - subdata = getindex(vnt, name) + subdata = _getindex(vnt, name) new_subdata = func(subdata) - return BangBang.setindex!!(vnt, new_subdata, name) + return _setindex!!(vnt, new_subdata, name) end function Base.map(func, vnt::VarNamedTuple) @@ -365,7 +465,7 @@ function Base.keys(vnt::VarNamedTuple) return result end -function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} +function _haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} if !haskey(vnt.data, S) return false end @@ -403,4 +503,35 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) return VarNamedTuple(result_data) end +# The following methods, indexing with ComposedFunction, are exactly the same for +# VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and +# inner lenses. +const VNT_OR_PA = Union{VarNamedTuple,PartialArray} + +function _getindex(x::VNT_OR_PA, optic::ComposedFunction) + subdata = _getindex(x, optic.inner) + return _getindex(subdata, optic.outer) +end + +function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction) + sub = if _haskey(vnt, optic.inner) + _setindex!!(_getindex(vnt, optic.inner), value, optic.outer) + else + make_leaf(value, optic.outer) + end + return _setindex!!(vnt, sub, optic.inner) +end + +function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) + return _haskey(vnt, optic.inner) && _haskey(_getindex(vnt, optic.inner), optic.outer) +end + +# The entry points for getting, setting, and checking, using the familiar functions. +Base.haskey(vnt::VarNamedTuple, key) = _haskey(vnt, key) +Base.getindex(vnt::VarNamedTuple, inds...) = _getindex(vnt, inds...) +BangBang.setindex!!(vnt::VarNamedTuple, value, inds...) = _setindex!!(vnt, value, inds...) +Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) +Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) +BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) 
+ end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 02ed3bca8..8cbf10a64 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,18 +1,32 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset -using DynamicPPL: @varname, VarNamedTuple +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using BangBang: setindex!! @testset "VarNamedTuple" begin + @testset "Construction" begin + pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() + pa1 = setindex!!(pa1, 1.0, 16) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,)) + pa2 = setindex!!(pa2, 1.0, 16) + @test pa1 == pa2 + end + @testset "Basic sets and gets" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 32.0 + @test haskey(vnt, @varname(a)) + @test !haskey(vnt, @varname(b)) vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + @test haskey(vnt, @varname(b)) + @test haskey(vnt, @varname(b[1])) + @test haskey(vnt, @varname(b[1:3])) + @test !haskey(vnt, @varname(b[4])) vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 64.0 @@ -42,6 +56,8 @@ using BangBang: setindex!! vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + @test haskey(vnt, @varname(e.f[3].g.h[2].i)) + @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 From 8f17dcf5fdf93a9d011a57b2b7b44cb2edd44981 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 15:25:05 +0000 Subject: [PATCH 20/45] Make PartialArray more type stable. --- src/varnamedtuple.jl | 36 +++++++++++++++++++++++++++++++++++- test/varnamedtuple.jl | 14 ++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 79a09f678..aed87017e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -179,6 +179,40 @@ function Base.hash(pa::PartialArray, h::UInt) return h end +""" + _concretise_eltype!!(pa::PartialArray) + +Concretise the element type of a `PartialArray`. + +Returns a new `PartialArray` with the same data and mask as `pa`, but with its element type +concretised to the most specific type that can hold all currently defined elements. + +Note that this function is fundamentally type unstable if the current element type of `pa` +is not already concrete. + +The name has a `!!` not because it mutates its argument, but because the return value +aliases memory with the argument, and is thus not independent of it. +""" +function _concretise_eltype!!(pa::PartialArray) + if isconcretetype(eltype(pa)) + return pa + end + new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) + # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)? + # In other words, does it help to be more concrete, even if we aren't fully concrete? + if new_et === eltype(pa) + # The types of the elements do not allow for concretisation. 
+ return pa + end + new_data = Array{new_et,ndims(pa)}(undef, _internal_size(pa)) + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] + new_data[i] = pa.data[i] + end + end + return PartialArray(new_data, pa.mask) +end + """Return the length needed in a dimension given an index.""" _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) @@ -283,7 +317,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else pa.mask[inds...] = true end - return PartialArray(new_data, pa.mask) + return _concretise_eltype!!(PartialArray(new_data, pa.mask)) end Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 8cbf10a64..08a65b018 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -101,6 +101,20 @@ using BangBang: setindex!! vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] @test !haskey(vnt, @varname(m[1])) + + # The below tests are mostly significant for the type stability aspect. For the last + # test to pass, PartialArray needs to actively tighten its eltype when possible. + vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[1].a))) + @test @inferred(getindex(vnt, @varname(n[1].a))) == 1.0 + vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[2].a))) + @test @inferred(getindex(vnt, @varname(n[2].a))) == 1.0 + # This can't be type stable, because n[1] has inhomogeneous types. + vnt = setindex!!(vnt, 1.0, @varname(n[1].b)) + @test getindex(vnt, @varname(n[1].b)) == 1.0 + # The setindex!! call can't be type stable either, but it should return a + # VarNamedTuple with a concrete element type, and hence getindex can be inferred. + vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) + @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 end @testset "equality" begin From 8547e250193496fcfedd0d9d950fe69dae65237c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 15:39:37 +0000 Subject: [PATCH 21/45] Implement `predict`, `returned`, `logjoint`, ... with `OnlyAccsVarInfo` (#1130) * Use OnlyAccsVarInfo for many re-evaluation functions * drop `fast_` prefix * Add a changelog --- HISTORY.md | 5 + ext/DynamicPPLMCMCChainsExt.jl | 184 ++++++++++++++------------------- src/chains.jl | 26 ----- src/logdensityfunction.jl | 16 +-- src/model.jl | 15 +-- 5 files changed, 102 insertions(+), 144 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 91306c219..ff28349d8 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,9 +32,14 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed. + The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). +The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, have been changed such that if the chain does not contain all variables in the model, an error is thrown. 
+Previously the behaviour would have been to sample missing variables.
+
 ## 0.38.9
 
 Remove warning when using Enzyme as the AD backend.
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index e74f0b8a9..8ad828648 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -1,41 +1,19 @@
 module DynamicPPLMCMCChainsExt
 
-using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
+using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
 using MCMCChains: MCMCChains
 
-_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
-
-function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
-    return _has_varname_to_symbol(chain.info)
-end
-
-function _check_varname_indexing(c::MCMCChains.Chains)
-    return DynamicPPL.supports_varname_indexing(c) ||
-        error("This `Chains` object does not support indexing using `VarName`s.")
-end
-
-function DynamicPPL.getindex_varname(
+function getindex_varname(
     c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
 )
-    _check_varname_indexing(c)
     return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
 end
-function DynamicPPL.varnames(c::MCMCChains.Chains)
-    _check_varname_indexing(c)
+function get_varnames(c::MCMCChains.Chains)
+    haskey(c.info, :varname_to_symbol) ||
+        error("This `Chains` object does not support indexing using `VarName`s.")
     return keys(c.info.varname_to_symbol)
 end
-function chain_sample_to_varname_dict(
-    c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
-) where {Tval}
-    _check_varname_indexing(c)
-    d = Dict{DynamicPPL.VarName,Tval}()
-    for vn in DynamicPPL.varnames(c)
-        d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
-    end
-    return d
-end
-
 """
     AbstractMCMC.from_samples(
        ::Type{MCMCChains.Chains},
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
     # Get parameters
     params_matrix = map(idxs) do (sample_idx, chain_idx)
         d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
-        for vn in DynamicPPL.varnames(chain)
-            d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
+        for vn in get_varnames(chain)
+            d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
         end
         d
     end
@@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples(
     return sort_chain ? sort(chain) : chain
 end
 
+"""
+    reevaluate_with_chain(
+        rng::AbstractRNG,
+        model::Model,
+        chain::MCMCChains.Chains,
+        accs::NTuple{N,AbstractAccumulator},
+        fallback=nothing,
+    )
+
+Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`,
+returning a matrix of `(retval, updated_varinfo)` tuples.
+
+This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
+initialisation strategy when re-evaluating the model. For many use cases the fallback should
+not be provided (as we expect the chain to contain all necessary variables), but for
+`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
+the posterior predictions). 
+""" +function reevaluate_with_chain( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + return map(params_with_stats) do ps + DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback)) + end +end +function reevaluate_with_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -245,30 +263,18 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - - # Set up a VarInfo with the right accumulators - varinfo = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - ), + accs = ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), ) - _, varinfo = DynamicPPL.init!!(model, varinfo) - varinfo = DynamicPPL.typed_varinfo(varinfo) - - params_and_stats = AbstractMCMC.to_samples( - DynamicPPL.ParamsWithStats, parameter_only_chain + predictions = map( + DynamicPPL.ParamsWithStats ∘ last, + reevaluate_with_chain( + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() + ), ) - predictions = map(params_and_stats) do ps - _, varinfo = DynamicPPL.init!!( - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) - ) - DynamicPPL.ParamsWithStats(varinfo) - end chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -348,18 +354,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - first( - DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), - ), - ) - end + return map(first, reevaluate_with_chain(model, chain, (), nothing)) end """ @@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities( ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) where {whichlogprob,Tout} - vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) - vi = DynamicPPL.setaccs!!(vi, (acc,)) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - pointwise_logps = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Re-evaluate the model - _, vi = DynamicPPL.init!!( - model, - vi, - DynamicPPL.InitFromParams(values_dict, 
DynamicPPL.InitFromPrior()), - ) - DynamicPPL.getacc(vi, Val(accname)).logps - end - + pointwise_logps = + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) + DynamicPPL.getacc(vi, Val(accname)).logps + end # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps @@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logjoint(model, argvals_dict) - end + return map( + DynamicPPL.getlogjoint ∘ last, + reevaluate_with_chain( + model, + chain, + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), + nothing, + ), + ) end """ @@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.loglikelihood(model, argvals_dict) - end + return map( + DynamicPPL.getloglikelihood ∘ last, + reevaluate_with_chain( + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing + ), + ) end """ @@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logprior(model, argvals_dict) - end + return map( + DynamicPPL.getlogprior ∘ last, + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), + ) end end diff --git a/src/chains.jl b/src/chains.jl index f176b8e68..2fcd4e713 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -1,29 +1,3 @@ -""" - supports_varname_indexing(chain::AbstractChains) - -Return `true` if `chain` supports indexing using `VarName` in place of the -variable name index. -""" -supports_varname_indexing(::AbstractChains) = false - -""" - getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) - -Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function getindex_varname end - -""" - varnames(chains::AbstractChains) - -Return an iterator over the varnames present in `chains`. 
- -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function varnames end - """ ParamsWithStats diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 65eab448e..bcdd0bb25 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -193,21 +193,21 @@ end # LogDensityProblems.jl interface # ################################### """ - fast_ldf_accs(getlogdensity::Function) + ldf_accs(getlogdensity::Function) Determine which accumulators are needed for fast evaluation with the given `getlogdensity` function. """ -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) +ldf_accs(::Function) = default_accumulators() +ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function ldf_accs(::typeof(getlogjoint)) return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -function fast_ldf_accs(::typeof(getlogprior_internal)) +function ldf_accs(::typeof(getlogprior_internal)) return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) +ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} model::M @@ -219,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) strategy = InitFromParams( VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - accs = fast_ldf_accs(f.getlogdensity) + accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) return f.getlogdensity(vi) end diff --git a/src/model.jl b/src/model.jl index 9029318b1..7d5bbf2fb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1181,12 +1181,15 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) # Note: we can't use `fix(model, parameters)` because # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 - # Use `nothing` as the fallback to ensure that any missing parameters cause an error - ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) - new_model = setleafcontext(model, ctx) - # We can't use new_model() because that overwrites it with an InitContext of its own. 
- return first(evaluate!!(new_model, vi)) + return first( + init!!( + model, + DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()), + # Use `nothing` as the fallback to ensure that any missing parameters cause an + # error + InitFromParams(parameters, nothing), + ), + ) end From 04b3383dafcfa9beb14cee411e42a9d2794043c3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 16:17:06 +0000 Subject: [PATCH 22/45] Fixes and improvements to VNT --- src/varnamedtuple.jl | 199 ++++++++++++++++++++++++------------------ test/varnamedtuple.jl | 81 +++++++++++++++++ 2 files changed, 194 insertions(+), 86 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index aed87017e..c8c7883dd 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -391,37 +391,57 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) return result end -struct VarNamedTuple{Names,Values} - data::NamedTuple{Names,Values} +function Base.keys(pa::PartialArray) + inds = findall(pa.mask) + lenses = map(x -> IndexLens(Tuple(x)), inds) + ks = Any[] + for l in lenses + val = getindex(pa.data, l.indices...) + if val isa VarNamedTuple + subkeys = keys(val) + for vn in subkeys + lens = varname_to_lens(vn) + push!(ks, _compose_no_identity(lens, l)) + end + else + push!(ks, l) + end + end + return ks end -# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for -# PartialArrays. -Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +""" + VarNamedTuple{Names,Values} -Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) +A `NamedTuple`-like structure with `VarName` keys. -function make_leaf(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,))) -end -make_leaf(value, ::typeof(identity)) = value -function make_leaf(value, optic::ComposedFunction) - sub = make_leaf(value, optic.outer) - return make_leaf(sub, optic.inner) -end +`VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an +efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and +`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods +are `merge`, which recursively merges two `VarNamedTuple`s. -function make_leaf(value, optic::IndexLens{T}) where {T} - inds = optic.indices - num_inds = length(inds) - # Check if any of the indices are ranges or colons. If yes, value needs to be an - # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray{et,num_inds}() - return setindex!!(iarr, value, optic) +The one major limitation is that indexing by `VarName`s with `Colon`s, (e.g. `a[:]`) is not +supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, +say `a[1]` and `a[3]`, are defined. + +`setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store +heterogeneous data under different indices of the same symbol. That is, if one either + +* sets `a[1]` and `a[2]` to be of different types, or +* sets `a[1].b` and `a[2].c`, without setting `a[1].c`. or `a[2].b`, + +then getting values for `a[1]` or `a[2]` will not be type stable. 
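+
+An illustrative sketch of this caveat (not a doctest; behaviour as described above):
+
+    vnt = VarNamedTuple()
+    vnt = setindex!!(vnt, 1.0, @varname(a[1]))
+    vnt = setindex!!(vnt, 2.0, @varname(a[2]))  # homogeneous values: indexing `a` stays type stable
+    vnt = setindex!!(vnt, "x", @varname(a[3]))  # mixed value types: indexing `a` is no longer type stable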
+""" +struct VarNamedTuple{Names,Values} + data::NamedTuple{Names,Values} end VarNamedTuple() = VarNamedTuple((;)) +Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) + +# TODO(mhauru) Rework this printing. function Base.show(io::IO, vnt::VarNamedTuple) print(io, "(") for (i, (name, value)) in enumerate(pairs(vnt.data)) @@ -434,26 +454,22 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +""" + varname_to_lens(name::VarName{S}) where {S} +Convert a `VarName` to an `Accessor` lens, wrapping the first symdol in a `PropertyLens`. +""" function varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end -function _getindex(vnt::VarNamedTuple, name::VarName) - return _getindex(vnt, varname_to_lens(name)) -end -function _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} - return _getindex(vnt.data, S) -end - -function _haskey(vnt::VarNamedTuple, name::VarName) - return _haskey(vnt, varname_to_lens(name)) -end - -_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true +_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name)) +_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _getindex(vnt.data, S) +_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] +_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name)) _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false function _setindex!!(vnt::VarNamedTuple, value, name::VarName) @@ -468,8 +484,41 @@ function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,)))) end -function apply(func, vnt::VarNamedTuple, name::VarName) - if !haskey(vnt.data, name.name) +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) + +# TODO(mhauru) Check the performance of this, and make it into a generated function if +# necessary. +function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + result_data = vnt1.data + for k in keys(vnt2.data) + val = if haskey(result_data, k) + _merge_recursive(result_data[k], vnt2.data[k]) + else + vnt2.data[k] + end + Accessors.@reset result_data[k] = val + end + return VarNamedTuple(result_data) +end + +""" + apply!!(func, vnt::VarNamedTuple, name::VarName) + +Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. 
+ +```jldoctest +julia> vnt = VarNamedTuple() +() + +julia> vnt = setindex!!(vnt, [1,2,3], @varname(a)) +(a -> [1, 2, 3]) + +julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a)) +(a -> [2, 3, 4]) +``` +""" +function apply!!(func, vnt::VarNamedTuple, name::VarName) + if !haskey(vnt, name) throw(KeyError(repr(name))) end subdata = _getindex(vnt, name) @@ -477,11 +526,6 @@ function apply(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end -function Base.map(func, vnt::VarNamedTuple) - new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) - return VarNamedTuple(new_data) -end - function Base.keys(vnt::VarNamedTuple) result = () for sym in keys(vnt.data) @@ -489,54 +533,18 @@ function Base.keys(vnt::VarNamedTuple) if subdata isa VarNamedTuple subkeys = keys(subdata) result = ( - (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result... + result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... ) + elseif subdata isa PartialArray + subkeys = keys(subdata) + result = (result..., (VarName{sym}(lens) for lens in subkeys)...) else - result = (VarName{sym}(), result...) + result = (result..., VarName{sym}()) end - subkeys = keys(vnt.data[sym]) end return result end -function _haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} - if !haskey(vnt.data, S) - return false - end - subdata = vnt.data[S] - return if Optic === typeof(identity) - true - elseif Optic <: IndexLens - try - AbstractPPL.getoptic(name)(subdata) - true - catch e - if e isa BoundsError || e isa KeyError - false - else - rethrow(e) - end - end - else - haskey(subdata, AbstractPPL.unprefix(name, VarName{S}())) - end -end - -# TODO(mhauru) Check the performance of this, and make it into a generated function if -# necessary. -function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - result_data = vnt1.data - for k in keys(vnt2.data) - val = if haskey(result_data, k) - _merge_recursive(result_data[k], vnt2.data[k]) - else - vnt2.data[k] - end - Accessors.@reset result_data[k] = val - end - return VarNamedTuple(result_data) -end - # The following methods, indexing with ComposedFunction, are exactly the same for # VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and # inner lenses. @@ -561,11 +569,30 @@ function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) end # The entry points for getting, setting, and checking, using the familiar functions. -Base.haskey(vnt::VarNamedTuple, key) = _haskey(vnt, key) -Base.getindex(vnt::VarNamedTuple, inds...) = _getindex(vnt, inds...) -BangBang.setindex!!(vnt::VarNamedTuple, value, inds...) = _setindex!!(vnt, value, inds...) +Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) +Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) +BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) 
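+
+# Example usage of these entry points (an illustrative sketch, not run as a test):
+#
+#   vnt = setindex!!(VarNamedTuple(), 1.0, @varname(a.b))
+#   haskey(vnt, @varname(a.b))   # true
+#   vnt[@varname(a.b)]           # 1.0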
+function make_leaf(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,))) +end +make_leaf(value, ::typeof(identity)) = value +function make_leaf(value, optic::ComposedFunction) + sub = make_leaf(value, optic.outer) + return make_leaf(sub, optic.inner) +end + +function make_leaf(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) + iarr = PartialArray{et,num_inds}() + return setindex!!(iarr, value, optic) +end + end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 08a65b018..e3e98d270 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -236,6 +236,87 @@ using BangBang: setindex!! expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) @test merge(vnt2, vnt1) == expected_merge_21 end + + @testset "keys" begin + vnt = VarNamedTuple() + @test keys(vnt) == () + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(a)) + @test keys(vnt) == (@varname(a),) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test keys(vnt) == (@varname(a), @varname(b)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test keys(vnt) == (@varname(a), @varname(b)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(j[6])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + + vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + ) + @test all(x -> haskey(vnt, x), keys(vnt)) + end end end From 59c4dcbba214d484faa4cbf76e206b76c38496da Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 17:47:38 +0000 Subject: [PATCH 23/45] Proper printing and constructors --- src/varnamedtuple.jl | 103 +++++++++++++++++++++++++++--------------- test/varnamedtuple.jl | 75 +++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 37 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index c8c7883dd..7880275a5 100644 --- 
a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,23 +8,23 @@ using ..DynamicPPL: _compose_no_identity export VarNamedTuple -# We define our own getindex, setindex!!, and haskey functions to be able to override their -# behaviour for some types exported from elsewhere without type piracy. This is needed -# because -# 1. We want to index into things with lenses (from Accessors.jl) using getindex and -# setindex!!. -# 2. We want to use getindex, setindex!!, and haskey as the universal functions for getting, -# setting, checking. This includes e.g. checking whether an index is valid for an Array, -# which would normally be done with checkbounds. -_haskey(x, key) = Base.haskey(x, key) -_getindex(x, inds...) = Base.getindex(x, inds...) -_setindex!!(x, value, inds...) = BangBang.setindex!!(x, value, inds...) -_getindex(arr::AbstractArray, optic::IndexLens) = _getindex(arr, optic.indices...) +# We define our own getindex, setindex!!, and haskey functions, which we use to +# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be +# able to override their behaviour for some types exported from elsewhere without type +# piracy. This is needed because +# 1. We would want to index into things with lenses (from Accessors.jl) using getindex and +# setindex!!, but Accessors does not define these methods. +# 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays. +function _getindex end +function _haskey end +function _setindex!! end + +_getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) _haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) +_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) function _setindex!!(arr::AbstractArray, value, optic::IndexLens) - return _setindex!!(arr, value, optic.indices...) + return setindex!!(arr, value, optic.indices...) end -_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) # Some utilities for checking what sort of indices we are dealing with. _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) @@ -122,22 +122,44 @@ purely a performance optimisation, to avoid resizing if the eventual size is kno time. """ function PartialArray{ElType,num_dims}( - min_size::Union{Tuple,Nothing}=nothing + args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing ) where {ElType,num_dims} - if min_size === nothing - dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + dims = if min_size === nothing + ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) else - dims = map(_partial_array_dim_size, min_size) + map(_partial_array_dim_size, min_size) end - dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{ElType,num_dims}(undef, dims) mask = fill(false, dims) - return PartialArray(data, mask) + pa = PartialArray(data, mask) + + for (inds, value) in args + pa = _setindex!!(pa, convert(ElType, value), inds...) 
+ end + return pa end Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims Base.eltype(::PartialArray{ElType}) where {ElType} = ElType +function Base.show(io::IO, pa::PartialArray) + print(io, "PartialArray{", eltype(pa), ",", ndims(pa), "}(") + is_first = true + for inds in CartesianIndices(pa.mask) + if @inbounds(!pa.mask[inds]) + continue + end + if !is_first + print(io, ", ") + is_first = false + end + val = @inbounds(pa.data[inds]) + print(io, Tuple(inds), " => ", val) + end + print(")") + return nothing +end + # We deliberately don't define Base.size for PartialArray, because it is ill-defined. # The size of the .data field is an implementation detail. _internal_size(pa::PartialArray, args...) = size(pa.data, args...) @@ -420,9 +442,10 @@ efficient and type stable manner. It is mainly used through `getindex`, `setinde `haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods are `merge`, which recursively merges two `VarNamedTuple`s. -The one major limitation is that indexing by `VarName`s with `Colon`s, (e.g. `a[:]`) is not -supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, -say `a[1]` and `a[3]`, are defined. +The there are two major limitations to indexing by VarNamedTuples: + +* `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined. +* Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`. `setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store heterogeneous data under different indices of the same symbol. That is, if one either @@ -436,20 +459,18 @@ struct VarNamedTuple{Names,Values} data::NamedTuple{Names,Values} end -VarNamedTuple() = VarNamedTuple((;)) +VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...)) Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) -# TODO(mhauru) Rework this printing. 
function Base.show(io::IO, vnt::VarNamedTuple) - print(io, "(") + print(io, "VarNamedTuple(;") for (i, (name, value)) in enumerate(pairs(vnt.data)) if i > 1 - print(io, ", ") + print(io, ",") end - print(io, name, " -> ") - print(io, value) + print(io, " ", name, "=", value) end return print(io, ")") end @@ -464,11 +485,11 @@ function varname_to_lens(name::VarName{S}) where {S} end _getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name)) -_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _getindex(vnt.data, S) +_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = getindex(vnt.data, S) _getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] _haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name)) -_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = _haskey(vnt.data, S) +_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false @@ -572,14 +593,24 @@ end Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) + Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) -function make_leaf(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,))) -end +""" + make_leaf(value, optic) + +Make a new leaf node for a VarNamedTuple. + +This is the function that sets any `optic` that is a `PropertyLens` to be stored as a +`VarNamedTuple`, any `IndexLens` to be stored as a `PartialArray`, and other `identity` +optics to be stored as raw values. It is the link that joins `VarNamedTuple` and +`PartialArray` together. +""" make_leaf(value, ::typeof(identity)) = value +make_leaf(value, ::PropertyLens{S}) where {S} = VarNamedTuple(NamedTuple{(S,)}((value,))) + function make_leaf(value, optic::ComposedFunction) sub = make_leaf(value, optic.outer) return make_leaf(sub, optic.inner) @@ -591,8 +622,8 @@ function make_leaf(value, optic::IndexLens{T}) where {T} # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) - iarr = PartialArray{et,num_inds}() - return setindex!!(iarr, value, optic) + pa = PartialArray{et,num_inds}() + return _setindex!!(pa, value, optic) end end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index e3e98d270..77edefa9a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -6,11 +6,43 @@ using BangBang: setindex!! 
@testset "VarNamedTuple" begin @testset "Construction" begin + vnt1 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) + vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + vnt2 = VarNamedTuple(; + a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) + ) + @test vnt1 == vnt2 + pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,)) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(; min_size=(16,)) pa2 = setindex!!(pa2, 1.0, 16) + pa3 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(16 => 1.0) + pa4 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,) => 1.0) + @test pa1 == pa2 + @test pa1 == pa3 + @test pa1 == pa4 + + pa1 = DynamicPPL.VarNamedTuples.PartialArray{String,3}() + pa1 = setindex!!(pa1, "a", 2, 3, 4) + pa1 = setindex!!(pa1, "b", 1, 2, 4) + pa2 = DynamicPPL.VarNamedTuples.PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = setindex!!(pa2, "a", 2, 3, 4) + pa2 = setindex!!(pa2, "b", 1, 2, 4) + pa3 = DynamicPPL.VarNamedTuples.PartialArray{String,3}( + (2, 3, 4) => "a", (1, 2, 4) => "b" + ) @test pa1 == pa2 + @test pa1 == pa3 + + @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1,) => "a") + @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}( + (1,) => 1; min_size=(2, 2) + ) end @testset "Basic sets and gets" begin @@ -317,6 +349,47 @@ using BangBang: setindex!! ) @test all(x -> haskey(vnt, x), keys(vnt)) end + + @testset "printing" begin + vnt = VarNamedTuple() + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(;)" + + vnt = setindex!!(vnt, 1.0, @varname(a)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(; a=1.0)" + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple(; a=1.0, b=[1, 2, 3])" + + vnt = setindex!!(vnt, 15, @varname(c[2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15)""" + + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(; a=1.0, b=[1, 2, 3], \ + c=PartialArray{Int64,1}((2,) => 15, \ + d=VarNamedTuple(; \ + e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ + Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ + VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => \ + 16.0(2,) => 17.0))))""" + end end end From 381b1dd4b1bb4646e61c962295908cbf015f9ff5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 19:05:57 +0000 Subject: [PATCH 24/45] Fix PartialArray printing --- src/varnamedtuple.jl | 3 ++- test/varnamedtuple.jl | 14 +++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 7880275a5..e47b27e9e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -151,12 +151,13 @@ function Base.show(io::IO, pa::PartialArray) end if !is_first print(io, ", ") + else is_first = false end val = 
@inbounds(pa.data[inds]) print(io, Tuple(inds), " => ", val) end - print(")") + print(io, ")") return nothing end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 77edefa9a..7b26e9be7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -374,7 +374,7 @@ using BangBang: setindex!! show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15)""" + VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() @@ -382,13 +382,13 @@ using BangBang: setindex!! output = String(take!(io)) @test output == """ VarNamedTuple(; a=1.0, b=[1, 2, 3], \ - c=PartialArray{Int64,1}((2,) => 15, \ + c=PartialArray{Int64,1}((2,) => 15), \ d=VarNamedTuple(; \ - e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ - Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ - Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ - VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => \ - 16.0(2,) => 17.0))))""" + e=PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ + VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0))))))""" end end From 88db66dd8496d772bb37333aca3ec6096c5e6e83 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Nov 2025 19:06:33 +0000 Subject: [PATCH 25/45] Update the design doc --- docs/src/internals/varnamedtuple.md | 173 ++++++++++++++++------------ 1 file changed, 99 insertions(+), 74 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 9f7a84cdb..0194d05d7 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -1,112 +1,137 @@ -# VarNamedTuple as the basis of VarInfo +# VarNamedTuple -This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that. +In DynamicPPL there is often a need to store data keyed by `VarName`s. +This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. +Historically we've had several different approaches to this: Dictionaries, NamedTuples, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. -## The current situation +To unify the treatment of these use cases, and handle them all in a robust and performant way, is the purpose of `VarNamedTuple`, aka VNT. +It's a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type stable and performant manner. -We currently have the following AbstractVarInfo types: +`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`. +Let's first talk about the `NamedTuple` part. +This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`. +In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lens as the keys. 
+For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as - - A: VarInfo with Metadata - - B: VarInfo with VarNamedVector - - C: VarInfo with NamedTuple, with values being Metadata - - D: VarInfo with NamedTuple, with values being VarNamedVector - - E: SimpleVarInfo with NamedTuples - - F: SimpleVarInfo with OrderedDict - -A and C are the classic ones, and the defaults. C wraps groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types and for the VarInfo to still be type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks are quite a lot slower, which needs work. +``` +VarNamedTuple(; x=1, y=VarNamedTuple(; z=2)) +``` -E and F are entirely distinct in implementation from the others. E is simply a mapping from Symbols to values, with each VarName being converted to a single symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type. +where `VarNamedTuple(; x=a, y=b)` is just a thin wrapper around the `NamedTuple` `(; x=a, y=b)`. -A-D carry within them values for variables, but also their bijectors/distributions, and store all values vectorised, using the bijectors to map to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit. +It's often handy to think of this as a tree, with each node being a `VarNamedTuple`, like so: -TODO: Write a better summary of pros and cons of each approach. +``` + VNT +x / \ y + 1 VNT + \ z + 2 +``` -## VarNamedTuple +If all `VarName`s consisted of only `PropertyLens`es we would be done designing the data structure. +However, recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). +The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`. +It is the `IndexLenses` that require more structure. -VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion. +An `IndexLens` is the indexing layer in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. +`VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. +Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. -An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well. 
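+
+As an illustrative sketch of what this restriction means in practice (not exhaustive):
+
+```julia
+vnt = VarNamedTuple()
+# Fine: integer indices and explicit ranges
+vnt = setindex!!(vnt, 1.0, @varname(x[3]))
+vnt = setindex!!(vnt, [1.0, 2.0], @varname(x[1:2]))
+# Not supported as keys: Colon-based indices such as @varname(x[:])
+```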
+When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`. +When we meet an `IndexLens`, we instead instert into the tree something called a `PartialArray`. -Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code. +A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset. +It is a data structure we define ourselves for use within `VarNamedTuple`s. +A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and this thus not an `AbstractArray`. +This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions. +The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are. -A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could hold with them extra information. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked. +This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`. +A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be. +If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question. -I sometimes shorten VarNamedTuple to VNT. +`PartialArray`s have other restrictions, compared to the full indexing syntax of Julia, as well: +They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays as in `rand(4)[[true, false, true, false]]`). +This is mostly because we haven't seen a need to support them, and implementing would complicate the codebase for little gain. +We may add support for them later if needed. -Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as +`PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s. +Thus we nest them with `VarNamedTuple`s to support storing `VarName`s with arbitrary combinations of `PropertyLens`es and `IndexLens`es. +A code example illustrates this the best: -``` -(; x=1, y=(; z=2)) -``` +```julia +julia> vnt = VarNamedTuple(); -(This is a slight simplification, really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.) -This forms a tree, with each node being a NamedTuple, like so: +julia> vnt = setindex!!(vnt, 1.0, @varname(a)); -``` - NT -x / \ y - 1 NT - \ z - 2 -``` +julia> vnt = setindex!!(vnt, [2.0, 3.0], @varname(b.c)); -Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has the keys `x` and `y`. This is like with the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). 
Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree. +julia> vnt = setindex!!(vnt, [:hip, :hop], @varname(d.e[2].f[3:4])); -For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`. +julia> print(vnt) +VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) +``` -This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: They are either the raw value itself if the last lens of the VarName is an `identity`, or otherwise they are something that can be indexed with an `IndexLens`, such as an `Array`. +The output there may be a bit hard bit hard to parse, so to illustrate: -To get a value from a VarNamedTuple is very simple: For `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it. +```julia +julia> vnt[@varname(b)] +VarNamedTuple(; c=[2.0, 3.0]) -Setting values in a VNT is equally simple if there are no `IndexLenses`: For `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`. +julia> vnt[@varname(b.c[1])] +2.0 -The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`. +julia> vnt[@varname(d.e)] +PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))) - 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: Just set the third element. - 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector? - 3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error? +julia> vnt[@varname(d.e[2].f)] +PartialArray{Symbol,1}((3,) => hip, (4,) => hop) +``` -A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: Any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary. +The above example also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. 
+We do not define a method for `Base.setindex!` at all, the `setindex!!` is the only way. +This is because `VarNamedTuple` mixes mutable an immutable data structures. +It is also for user convenience: +One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. +Rather the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed. +`VarNamedTuple`, or more precisely `PartialArray`, even explicitly concretises element types whenever possible. +For instance, one can make an abstractly typed `VarNamedTuple` like so: -To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this: +```julia +julia> vnt = VarNamedTuple(); -Unlike I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in the leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`. +julia> vnt = setindex!!(vnt, 1.0, @varname(a[1])); -Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like +julia> vnt = setindex!!(vnt, "hello", @varname(a[2])); -``` - --NT-- - x / \ y -f(1, identity) NT - \ z - f(2, identity) +julia> print(vnt) +VarNamedTuple(; a=PartialArray{Any,1}((1,) => 1.0, (2,) => hello)) ``` -The above, first draft of VNT which did not include `make_leaf` is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal IndexLenses")`. The problems 2. and 3. above are "solved" by making it `make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type, that can grow as needed when indexed into. +Note the element type of `PartialArray{Any}`. +But if one changes the values to make them homogeneous, the element type is automatically made concrete again: -The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and ,if necessary, leave the option for power users to customise behaviour. The first ones to implement would be +```julia +julia> vnt = setindex!!(vnt, "me here", @varname(a[1])); - - `make_leaf` that returns a Metadata object. This would be a direct replacement for type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple. - - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict. - -You may ask, have we simple gone from too many VarInfo types to too many `make_leaf` functions. Yes we have. But hopefully we have gained something in the process: - - - The leaf types can be simpler. They do not need to deal with VarNames any more, they only need to deal with `identity` lenses and `IndexLenses`. - - All AbstactVarInfos are as type stable as their leaf types allow. There is no more notion of an untyped VarInfo being converted to a typed one. 
- - Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels. - - Many functions that are currently implemented individually for each AbstactVarInfo type would now have a single implementation for the VarNamedTuple-based AbstactVarInfo type, reducing code duplication. I would also hope to get ride of most of the generated functions for in `varinfo.jl`. - -My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know. +julia> print(vnt) +VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello)) +``` -I think the two big questions are: +This approach is at the core of why `VarNamedTuple` is performant: +As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type stable. +Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for for `getindex` and `setindex!!`. +Although, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`, you can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability. - - Will we run into some big, unanticipated blockers when we start to implement this. - - Will the nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up. +Some miscellaneous notes -I'll try to derisk these early on in this PR. +## Limitations -## Questions / issues +This design has a several of benefits, for performance and generality, but it also has limitations: - - People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done. - - When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that? - - Do `Colon` indices cause any extra trouble for the leafnodes? + 1. The lack of support for `Colon`s in `VarName`s. + 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing. + 3. An assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. + The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`. + The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`. 
From a6d56a2b9074d9da27eea4a6e4a2ab9a3013913f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 27 Nov 2025 12:08:24 +0000 Subject: [PATCH 26/45] Improve FastLDF type stability when all parameters are linked or unlinked (#1141) * Improve type stability when all parameters are linked or unlinked * fix a merge conflict * fix enzyme gc crash (locally at least) * Fixes from review --- src/chains.jl | 8 +++-- src/contexts/init.jl | 33 ++++++++++++++++--- src/logdensityfunction.jl | 56 +++++++++++++++++++++++++-------- test/integration/enzyme/main.jl | 10 ++++-- test/logdensityfunction.jl | 16 ++++++++++ 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 2fcd4e713..d01606c3d 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -130,13 +130,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.LogDensityFunction, + ldf::DynamicPPL.LogDensityFunction{Tlink}, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, -) +) where {Tlink} strategy = InitFromParams( - VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), nothing, ) accs = if include_log_probs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..80a494c23 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -223,6 +223,12 @@ end A struct that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. +The type parameter `Tlink` can be either `true` or `false`, to mark that the variables in +this `VectorWithRanges` are linked/not linked, or `nothing` if either the linking status is +not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not +affect functionality or correctness, but causes more work to be done at runtime, with +possible impacts on type stability and performance. + In the simplest case, this could be accomplished only with a single dictionary mapping VarNames to ranges and link status. However, for performance reasons, we separate out VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All @@ -231,13 +237,26 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
""" -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + if !(Tlink isa Union{Bool,Nothing}) + throw( + ArgumentError( + "VectorWithRanges type parameter has to be one of `true`, `false`, or `nothing`.", + ), + ) + end + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,11 +271,15 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, -) + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} vr = p.params range_and_linked = _get_range_and_linked(vr, vn) - transform = if range_and_linked.is_linked + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked from_linked_vec_transform(dist) else from_vec_transform(dist) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index bcdd0bb25..7d1094fa3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -163,6 +166,21 @@ struct LogDensityFunction{ # Figure out which variable corresponds to which index, and # which variables are linked. 
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Figure out if all variables are linked, unlinked, or mixed + link_statuses = Bool[] + for ral in all_iden_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in all_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -172,12 +190,13 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -209,15 +228,24 @@ end ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) @@ -225,9 +253,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params @@ -235,10 +263,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..edfd67d18 100644 --- 
a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,11 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => +ADTYPES = ( + ( + "EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), - "EnzymeReverse" => + ), + ( + "EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), + ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 06492d6e1..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -108,6 +108,22 @@ end end end +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when From eca65d5cd6f6147199f6fddb6f9be009abcbf454 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:14:27 +0000 Subject: [PATCH 27/45] Fix a test --- test/varnamedtuple.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7b26e9be7..a93db10cc 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -384,8 +384,8 @@ using BangBang: setindex!! VarNamedTuple(; a=1.0, b=[1, 2, 3], \ c=PartialArray{Int64,1}((2,) => 15), \ d=VarNamedTuple(; \ - e=PartialArray{VarNamedTuple{(:f,), \ - Tuple{VarNamedTuple{(:g,), \ + e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ + Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0))))))""" From 5e27a052725226c6cf15b0e44a38d63cd39eb032 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:45:12 +0000 Subject: [PATCH 28/45] Fix copy and show --- src/varnamedtuple.jl | 39 ++++++++++++++++++++++++++++++++++++--- test/varnamedtuple.jl | 18 +++++++++--------- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index e47b27e9e..2068566b4 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -155,7 +155,12 @@ function Base.show(io::IO, pa::PartialArray) is_first = false end val = @inbounds(pa.data[inds]) - print(io, Tuple(inds), " => ", val) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + show(io, Tuple(inds)) + print(io, " => ") + show(io, val) end print(io, ")") return nothing @@ -166,7 +171,17 @@ end _internal_size(pa::PartialArray, args...) = size(pa.data, args...) function Base.copy(pa::PartialArray) - return PartialArray(copy(pa.data), copy(pa.mask)) + # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively + # copy. 
+ pa_copy = PartialArray(copy(pa.data), copy(pa.mask)) + if VarNamedTuple <: eltype(pa) || eltype(pa) <: VarNamedTuple + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] && pa_copy.data[i] isa VarNamedTuple + pa_copy.data[i] = copy(pa.data[i]) + end + end + end + return pa_copy end function Base.:(==)(pa1::PartialArray, pa2::PartialArray) @@ -471,11 +486,29 @@ function Base.show(io::IO, vnt::VarNamedTuple) if i > 1 print(io, ",") end - print(io, " ", name, "=", value) + print(io, " ") + print(io, name) + print(io, "=") + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the value itself is `show`n. The latter ensures that + # strings are quoted, Symbols are prefixed with :, etc. + show(io, value) end return print(io, ")") end +function Base.copy(vnt::VarNamedTuple{Names}) where {Names} + # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements, + # which we recursively copy. + return VarNamedTuple( + NamedTuple{Names}( + map( + x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data) + ), + ), + ) +end + """ varname_to_lens(name::VarName{S}) where {S} diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index a93db10cc..ad5fba8c1 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -357,35 +357,35 @@ using BangBang: setindex!! output = String(take!(io)) @test output == "VarNamedTuple(;)" - vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, "s", @varname(a)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == "VarNamedTuple(; a=1.0)" + @test output == """VarNamedTuple(; a="s")""" vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == "VarNamedTuple(; a=1.0, b=[1, 2, 3])" + @test output == """VarNamedTuple(; a="s", b=[1, 2, 3])""" - vnt = setindex!!(vnt, 15, @varname(c[2])) + vnt = setindex!!(vnt, :dada, @varname(c[2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], c=PartialArray{Int64,1}((2,) => 15))""" + VarNamedTuple(; a="s", b=[1, 2, 3], c=PartialArray{Symbol,1}((2,) => :dada))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a=1.0, b=[1, 2, 3], \ - c=PartialArray{Int64,1}((2,) => 15), \ + VarNamedTuple(; a="s", b=[1, 2, 3], \ + c=PartialArray{Symbol,1}((2,) => :dada), \ d=VarNamedTuple(; \ - e=PartialArray{DynamicPPL.VarNamedTuples.VarNamedTuple{(:f,), \ - Tuple{DynamicPPL.VarNamedTuples.VarNamedTuple{(:g,), \ + e=PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0))))))""" From 050b8c54ca435a50a6cf24bb4052b5a65385b0f9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 13:49:41 +0000 Subject: [PATCH 29/45] Add test_invariants to VNT tests --- test/varnamedtuple.jl | 67 +++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index ad5fba8c1..f55f8b996 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -2,47 +2,67 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple +using 
DynamicPPL.VarNamedTuples: PartialArray using BangBang: setindex!! +""" + test_invariants(vnt::VarNamedTuple) + +Test properties that should hold for all VarNamedTuples. + +Uses @test for all the tests. Intended to be called inside a @testset. +""" +function test_invariants(vnt::VarNamedTuple) + # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. + for k in keys(vnt) + @test haskey(vnt, k) + v = getindex(vnt, k) + vnt2 = setindex!!(copy(vnt), v, k) + @test vnt == vnt2 + end + # Check that the printed representation can be parsed back to an equal VarNamedTuple. + vnt3 = eval(Meta.parse(repr(vnt))) + @test vnt == vnt3 +end + @testset "VarNamedTuple" begin @testset "Construction" begin vnt1 = VarNamedTuple() + test_invariants(vnt1) vnt1 = setindex!!(vnt1, 1.0, @varname(a)) vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + test_invariants(vnt1) vnt2 = VarNamedTuple(; a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) ) + test_invariants(vnt2) @test vnt1 == vnt2 - pa1 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}() + pa1 = PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(; min_size=(16,)) + pa2 = PartialArray{Float64,1}(; min_size=(16,)) pa2 = setindex!!(pa2, 1.0, 16) - pa3 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}(16 => 1.0) - pa4 = DynamicPPL.VarNamedTuples.PartialArray{Float64,1}((16,) => 1.0) + pa3 = PartialArray{Float64,1}(16 => 1.0) + pa4 = PartialArray{Float64,1}((16,) => 1.0) @test pa1 == pa2 @test pa1 == pa3 @test pa1 == pa4 - pa1 = DynamicPPL.VarNamedTuples.PartialArray{String,3}() + pa1 = PartialArray{String,3}() pa1 = setindex!!(pa1, "a", 2, 3, 4) pa1 = setindex!!(pa1, "b", 1, 2, 4) - pa2 = DynamicPPL.VarNamedTuples.PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = PartialArray{String,3}(; min_size=(16, 16, 16)) pa2 = setindex!!(pa2, "a", 2, 3, 4) pa2 = setindex!!(pa2, "b", 1, 2, 4) - pa3 = DynamicPPL.VarNamedTuples.PartialArray{String,3}( - (2, 3, 4) => "a", (1, 2, 4) => "b" - ) + pa3 = PartialArray{String,3}((2, 3, 4) => "a", (1, 2, 4) => "b") @test pa1 == pa2 @test pa1 == pa3 - @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((0,) => 1) - @test_throws BoundsError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1, 2) => 1) - @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}((1,) => "a") - @test_throws MethodError DynamicPPL.VarNamedTuples.PartialArray{Int,1}( - (1,) => 1; min_size=(2, 2) - ) + @test_throws BoundsError PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError PartialArray{Int,1}((1,) => "a") + @test_throws MethodError PartialArray{Int,1}((1,) => 1; min_size=(2, 2)) end @testset "Basic sets and gets" begin @@ -51,6 +71,7 @@ using BangBang: setindex!! @test @inferred(getindex(vnt, @varname(a))) == 32.0 @test haskey(vnt, @varname(a)) @test !haskey(vnt, @varname(b)) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] @@ -59,40 +80,50 @@ using BangBang: setindex!! 
@test haskey(vnt, @varname(b[1])) @test haskey(vnt, @varname(b[1:3])) @test !haskey(vnt, @varname(b[4])) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 64.0 @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + test_invariants(vnt) # These can't be @inferred because `d` now has an abstract element type. Note that this # does not ruin type stability for other varnames that don't involve `d`. vnt = setindex!!(vnt, "a", @varname(d[5])) @test getindex(vnt, @varname(d[5])) == "a" + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 @test haskey(vnt, @varname(e.f[3].g.h[2].i)) @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + test_invariants(vnt) vec = fill(1.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) @@ -101,12 +132,14 @@ using BangBang: setindex!! @test haskey(vnt, @varname(j[4])) @test !haskey(vnt, @varname(j[5])) @test_throws BoundsError getindex(vnt, @varname(j[5])) + test_invariants(vnt) vec = fill(2.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec @test haskey(vnt, @varname(j[5])) + test_invariants(vnt) arr = fill(2.0, (4, 2)) vn = @varname(k.l[2:5, 3, 1:2, 2]) @@ -114,6 +147,7 @@ using BangBang: setindex!! @test @inferred(getindex(vnt, vn)) == arr # A subset of the elements set just now. @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + test_invariants(vnt) # Not enough, or too many, indices. @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) @@ -128,11 +162,13 @@ using BangBang: setindex!! # A subset of the elements set previously. @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + test_invariants(vnt) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] @test !haskey(vnt, @varname(m[1])) + test_invariants(vnt) # The below tests are mostly significant for the type stability aspect. For the last # test to pass, PartialArray needs to actively tighten its eltype when possible. @@ -147,6 +183,7 @@ using BangBang: setindex!! # VarNamedTuple with a concrete element type, and hence getindex can be inferred. 
vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 + test_invariants(vnt) end @testset "equality" begin From f5616df867e2a685783686dbf1d382888a510974 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 16:44:23 +0000 Subject: [PATCH 30/45] Improve VNT internal docs --- docs/src/internals/varnamedtuple.md | 47 ++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 0194d05d7..7198aae9f 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -1,16 +1,16 @@ -# VarNamedTuple +# `VarNamedTuple` In DynamicPPL there is often a need to store data keyed by `VarName`s. This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. -Historically we've had several different approaches to this: Dictionaries, NamedTuples, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. +Historically we've had several different approaches to this: Dictionaries, `NamedTuple`s, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. To unify the treatment of these use cases, and handle them all in a robust and performant way, is the purpose of `VarNamedTuple`, aka VNT. It's a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type stable and performant manner. -`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`. +`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`s. Let's first talk about the `NamedTuple` part. This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`. -In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lens as the keys. +In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lenses as keys. For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as ``` @@ -30,11 +30,11 @@ x / \ y ``` If all `VarName`s consisted of only `PropertyLens`es we would be done designing the data structure. -However, recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). +However, recall that `VarName`s allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`. It is the `IndexLenses` that require more structure. -An `IndexLens` is the indexing layer in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. +An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. `VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. 
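+For a rough, illustrative sketch of which index forms are in scope (the variable names here are hypothetical):
+
+```julia
+using DynamicPPL: @varname
+
+# Supported as keys into a VarNamedTuple:
+@varname(x[1])        # plain integer indices
+@varname(x[2:4])      # explicit ranges with end points
+@varname(x[1, 3:5])   # tuples of integers and ranges
+
+# Not supported as keys: Colon indices such as @varname(x[:]),
+# linear indexing into multidimensional arrays, and boolean masks.
+```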
@@ -43,7 +43,7 @@ When we meet an `IndexLens`, we instead instert into the tree something called a A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset. It is a data structure we define ourselves for use within `VarNamedTuple`s. -A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and this thus not an `AbstractArray`. +A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and thus is not an `AbstractArray`. This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions. The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are. @@ -52,8 +52,8 @@ A `Colon()` says that we should get or set all the values along that dimension, If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question. `PartialArray`s have other restrictions, compared to the full indexing syntax of Julia, as well: -They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays as in `rand(4)[[true, false, true, false]]`). -This is mostly because we haven't seen a need to support them, and implementing would complicate the codebase for little gain. +They do not support linearly indexing into multidimemensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays (as in `rand(4)[[true, false, true, false]]`). +This is mostly because we haven't seen a need to support them, and implementing them would complicate the codebase for little gain. We may add support for them later if needed. `PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s. @@ -89,8 +89,20 @@ julia> vnt[@varname(d.e[2].f)] PartialArray{Symbol,1}((3,) => hip, (4,) => hop) ``` -The above example also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. -We do not define a method for `Base.setindex!` at all, the `setindex!!` is the only way. +Or as a tree drawing, where `PA` marks a `PartialArray`: + +``` + /----VNT------\ +a / | b \ d + 1 [2.0, 3.0] VNT + | e + PA(2 => VNT) + | f + PA(3 => :hip, 4 => :hop) +``` + +The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. +We do not define a method for `Base.setindex!` at all, `setindex!!` is the only way. This is because `VarNamedTuple` mixes mutable an immutable data structures. It is also for user convenience: One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. @@ -122,9 +134,15 @@ VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello)) This approach is at the core of why `VarNamedTuple` is performant: As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type stable. 
Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for `getindex` and `setindex!!`.
-Although, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`, you can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability.
+Although, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`;
+You can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability.

-Some miscellaneous notes
+Note that if you `setindex!!` a new value into a `VarNamedTuple` with an `IndexLens`, this causes a `PartialArray` to be created.
+However, if there already is a regular `Base.Array` stored in a `VarNamedTuple`, you can index into it with `IndexLens`es without involving `PartialArray`s.
+That is, if you do `vnt = setindex!!(vnt, [1.0, 2.0], @varname(a))`, you can then get the values with e.g. `vnt[@varname(a[1])]`, which returns 1.0.
+You can also set the elements with `vnt = setindex!!(vnt, 3.0, @varname(a[1]))`, and this will modify the existing `Base.Array`.
+At this point you cannot set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, 5.0, @varname(a[5]))`.
+The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is.

 ## Limitations

 This design has several benefits, for performance and generality, but it also has limitations:

  1. The lack of support for `Colon`s in `VarName`s.
  2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing.
- 3. An assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`.
+ 3. `VarNamedTuple` can not store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` can not be stored in the same `VarNamedTuple`.
+ 4. There is an assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`.
    The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`.
    The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`.

From ec5dc8f0a53029475418cefb062fe9fe346b2a7b Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 27 Nov 2025 17:41:40 +0000
Subject: [PATCH 31/45] Polish VNT

---
 src/varnamedtuple.jl  | 74 +++++++++++++++++++++++++++----------------
 test/varnamedtuple.jl |  8 ++++-
 2 files changed, 54 insertions(+), 28 deletions(-)

diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl
index 2068566b4..47340f8f4 100644
--- a/src/varnamedtuple.jl
+++ b/src/varnamedtuple.jl
@@ -43,8 +43,8 @@ nested containers, and calls itself recursively on all elements that are found i
 `x1` and `x2`. 
In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]` -is `y[a]`, whereas `_merge_recursive(x, y)[a]` be `_merge_recursive(x[a], y[a])`, unless -no specific method is defined for the type of `x` and `y`, in which case +is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`, +unless no specific method is defined for the type of `x` and `y`, in which case `_merge_recursive(x, y) === y` """ _merge_recursive(_, x2) = x2 @@ -81,12 +81,18 @@ way of saying whether the right hand side is of an acceptable size or not. The fact that its size is ill-defined also means that `PartialArray` is not a subtype of `AbstractArray`. -All indexing into `PartialArray`s are done with `getindex` and `setindex!!`. `setindex!`, +All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`, `push!`, etc. are not defined. The element type of a `PartialArray` will change as needed under `setindex!!` to accomoddate the new values. Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type -`ElType` and number of dimensions `numdims`. +`ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly +`numdims` elements. + +If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check +if, after the new value has been set, the element type can be made more concrete. If so, +a new `PartialArray` with a more concrete element type is returned. Thus the element type +of any `PartialArray` should always be as concrete as is allowed by the elements in it. The internal implementation of an `PartialArray` consists of two arrays: one holding the data and the other one being a boolean mask indicating which elements are defined. These @@ -113,13 +119,20 @@ struct PartialArray{ElType,num_dims} end """ - PartialArray{ElType,num_dims}(min_size=nothing) + PartialArray{ElType,num_dims}(args::Vararg{Pair}; min_size=nothing) + +Create a new `PartialArray`. -Create a new empty `PartialArray` with set element type and number of dimensions. +The element type and number of dimensions have to be specified explicitly as type +parameters. The positional arguments can be `Pair`s of indices and values. For example, +```jldoctest +julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) +PartialArray{Int,2}((1, 2) => 5, (3, 4) => 10) +``` -The optional argument `min_size` can be used to specify the minimum initial size. This is -purely a performance optimisation, to avoid resizing if the eventual size is known ahead of -time. +The optional keywoard argument `min_size` can be used to specify the minimum initial size. +This is purely a performance optimisation, to avoid resizing if the eventual size is known +ahead of time. """ function PartialArray{ElType,num_dims}( args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing @@ -376,12 +389,12 @@ end function _merge_recursive(pa1::PartialArray, pa2::PartialArray) if ndims(pa1) != ndims(pa2) throw( - ArgumentError("Cannot merge PartialArrays with different number of dimensions") + ArgumentError("Cannot merge PartialArrays with different numbers of dimensions") ) end num_dims = ndims(pa1) merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) - result = if merge_size == _internal_size(pa2) + return if merge_size == _internal_size(pa2) # Either pa2 is strictly bigger than pa1 or they are equal in size. 
result = copy(pa2) for i in CartesianIndices(pa1.data) @@ -426,23 +439,22 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) result end end - return result end function Base.keys(pa::PartialArray) inds = findall(pa.mask) lenses = map(x -> IndexLens(Tuple(x)), inds) ks = Any[] - for l in lenses - val = getindex(pa.data, l.indices...) + for lens in lenses + val = getindex(pa.data, lens.indices...) if val isa VarNamedTuple subkeys = keys(val) for vn in subkeys - lens = varname_to_lens(vn) - push!(ks, _compose_no_identity(lens, l)) + sublens = _varname_to_lens(vn) + push!(ks, _compose_no_identity(sublens, lens)) end else - push!(ks, l) + push!(ks, lens) end end return ks @@ -455,8 +467,8 @@ A `NamedTuple`-like structure with `VarName` keys. `VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and -`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods -are `merge`, which recursively merges two `VarNamedTuple`s. +`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Anther notable methods +is `merge`, which recursively merges two `VarNamedTuple`s. The there are two major limitations to indexing by VarNamedTuples: @@ -470,6 +482,9 @@ heterogeneous data under different indices of the same symbol. That is, if one e * sets `a[1].b` and `a[2].c`, without setting `a[1].c`. or `a[2].b`, then getting values for `a[1]` or `a[2]` will not be type stable. + +`VarNamedTuple` is intrinsically linked to `PartialArray`, which it'll use to store data +related to `VarName`s with `IndexLens` components. """ struct VarNamedTuple{Names,Values} data::NamedTuple{Names,Values} @@ -513,26 +528,29 @@ end varname_to_lens(name::VarName{S}) where {S} Convert a `VarName` to an `Accessor` lens, wrapping the first symdol in a `PropertyLens`. + +This is used to simplify method dispatch for `_getindx`, `_setindex!!`, and `_haskey`, by +considering `VarName`s to just be a special case of lenses. """ -function varname_to_lens(name::VarName{S}) where {S} +function _varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end -_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, varname_to_lens(name)) +_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, _varname_to_lens(name)) _getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = getindex(vnt.data, S) _getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] -_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, varname_to_lens(name)) +_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, _varname_to_lens(name)) _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false function _setindex!!(vnt::VarNamedTuple, value, name::VarName) - return _setindex!!(vnt, value, varname_to_lens(name)) + return _setindex!!(vnt, value, _varname_to_lens(name)) end function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} - # I would like this to just read + # I would like for this to just read # return VarNamedTuple(_setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the # below? 
@@ -556,6 +574,8 @@ function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) return VarNamedTuple(result_data) end +# TODO(mhauru) The below remains unfinished an undertested. I think it's incorrect for more +# complex VarNames. It is unexported though. """ apply!!(func, vnt::VarNamedTuple, name::VarName) @@ -565,7 +585,7 @@ Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name julia> vnt = VarNamedTuple() () -julia> vnt = setindex!!(vnt, [1,2,3], @varname(a)) +julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a)) (a -> [1, 2, 3]) julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a)) @@ -650,12 +670,12 @@ function make_leaf(value, optic::ComposedFunction) return make_leaf(sub, optic.inner) end -function make_leaf(value, optic::IndexLens{T}) where {T} +function make_leaf(value, optic::IndexLens) inds = optic.indices num_inds = length(inds) # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value) + et = _is_multiindex(inds) ? eltype(value) : typeof(value) pa = PartialArray{et,num_inds}() return _setindex!!(pa, value, optic) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index f55f8b996..803d8c546 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -19,10 +19,12 @@ function test_invariants(vnt::VarNamedTuple) v = getindex(vnt, k) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 + @test hash(vnt) == hash(vnt2) end # Check that the printed representation can be parsed back to an equal VarNamedTuple. vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 + @test hash(vnt) == hash(vnt3) end @testset "VarNamedTuple" begin @@ -417,13 +419,17 @@ end io = IOBuffer() show(io, vnt) output = String(take!(io)) + # Depending on what's in scope, and maybe sometimes even the Julia version, + # sometimes types in the output are fully qualified, sometimes not. To avoid + # brittle tests, we normalise the output: + output = replace(output, "DynamicPPL." => "", "VarNamedTuples." => "") @test output == """ VarNamedTuple(; a="s", b=[1, 2, 3], \ c=PartialArray{Symbol,1}((2,) => :dada), \ d=VarNamedTuple(; \ e=PartialArray{VarNamedTuple{(:f,), \ Tuple{VarNamedTuple{(:g,), \ - Tuple{DynamicPPL.VarNamedTuples.PartialArray{Float64, 1}}}}},1}((3,) => \ + Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0))))))""" end From 3ca36c48b9b52175a788407a90d11da940e12cac Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 18:26:06 +0000 Subject: [PATCH 32/45] Make VNT merge type stable. Simplify printing, improve tests. --- src/varnamedtuple.jl | 55 ++++++++++++++++++---------------- test/varnamedtuple.jl | 70 +++++++++++++++++++++++++++++++------------ 2 files changed, 80 insertions(+), 45 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 47340f8f4..e49e1cb66 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -461,7 +461,7 @@ function Base.keys(pa::PartialArray) end """ - VarNamedTuple{Names,Values} + VarNamedTuple{names,Values} A `NamedTuple`-like structure with `VarName` keys. 
@@ -496,27 +496,19 @@ Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) function Base.show(io::IO, vnt::VarNamedTuple) - print(io, "VarNamedTuple(;") - for (i, (name, value)) in enumerate(pairs(vnt.data)) - if i > 1 - print(io, ",") - end - print(io, " ") - print(io, name) - print(io, "=") - # Note the distinction: The raw strings that form part of the structure of the print - # out are `print`ed, whereas the value itself is `show`n. The latter ensures that - # strings are quoted, Symbols are prefixed with :, etc. - show(io, value) + if isempty(vnt.data) + return print(io, "VarNamedTuple()") end - return print(io, ")") + print(io, "VarNamedTuple") + show(io, vnt.data) + return nothing end -function Base.copy(vnt::VarNamedTuple{Names}) where {Names} +function Base.copy(vnt::VarNamedTuple{names}) where {names} # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements, # which we recursively copy. return VarNamedTuple( - NamedTuple{Names}( + NamedTuple{names}( map( x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data) ), @@ -559,19 +551,25 @@ end Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) -# TODO(mhauru) Check the performance of this, and make it into a generated function if -# necessary. -function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple) - result_data = vnt1.data - for k in keys(vnt2.data) - val = if haskey(result_data, k) - _merge_recursive(result_data[k], vnt2.data[k]) +# This needs to be a generated function for type stability. +@generated function _merge_recursive( + vnt1::VarNamedTuple{names1}, vnt2::VarNamedTuple{names2} +) where {names1,names2} + all_names = union(names1, names2) + exs = Expr[] + push!(exs, :(data = (;))) + for name in all_names + val_expr = if name in names1 && name in names2 + :(_merge_recursive(vnt1.data[$(QuoteNode(name))], vnt2.data[$(QuoteNode(name))])) + elseif name in names1 + :(vnt1.data[$(QuoteNode(name))]) else - vnt2.data[k] + :(vnt2.data[$(QuoteNode(name))]) end - Accessors.@reset result_data[k] = val + push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,))))) end - return VarNamedTuple(result_data) + push!(exs, :(return VarNamedTuple(data))) + return Expr(:block, exs...) end # TODO(mhauru) The below remains unfinished an undertested. I think it's incorrect for more @@ -601,6 +599,11 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end +# TODO(mhauru) Should this return tuples, like it does now? That makes sense for +# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big. +# Also, this is not very type stable, it fails even in basic cases. A generated function +# would help, but I failed to make one. Might be something to do with a recursive +# generated function. function Base.keys(vnt::VarNamedTuple) result = () for sym in keys(vnt.data) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 803d8c546..53ce10e94 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,6 +3,7 @@ module VarNamedTupleTests using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: PartialArray +using AbstractPPL: VarName, prefix using BangBang: setindex!! 
""" @@ -25,6 +26,9 @@ function test_invariants(vnt::VarNamedTuple) vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 @test hash(vnt) == hash(vnt3) + # Check that merge with an empty VarNamedTuple is a no-op. + @test merge(vnt, VarNamedTuple()) == vnt + @test merge(VarNamedTuple(), vnt) == vnt end @testset "VarNamedTuple" begin @@ -186,6 +190,32 @@ end vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 test_invariants(vnt) + + # Some funky Symbols in VarNames + # TODO(mhauru) This still isn't as robust as it should be, for instance Symbol(":") + # fails the eval(Meta.parse(print(vnt))) == vnt test because NamedTuple show doesn't + # respect the eval-property. + vn1 = VarName{Symbol("a b c")}() + vnt = @inferred(setindex!!(vnt, 2, vn1)) + @test @inferred(getindex(vnt, vn1)) == 2 + test_invariants(vnt) + vn2 = VarName{Symbol("1")}() + vnt = @inferred(setindex!!(vnt, 3, vn2)) + @test @inferred(getindex(vnt, vn2)) == 3 + test_invariants(vnt) + vn3 = VarName{Symbol("?!")}() + vnt = @inferred(setindex!!(vnt, 4, vn3)) + @test @inferred(getindex(vnt, vn3)) == 4 + test_invariants(vnt) + vnt = VarNamedTuple() + vn4 = prefix(prefix(vn1, vn2), vn3) + vnt = @inferred(setindex!!(vnt, 5, vn4)) + @test @inferred(getindex(vnt, vn4)) == 5 + test_invariants(vnt) + vn5 = prefix(prefix(vn3, vn2), vn1) + vnt = @inferred(setindex!!(vnt, 6, vn5)) + @test @inferred(getindex(vnt, vn5)) == 6 + test_invariants(vnt) end @testset "equality" begin @@ -229,7 +259,7 @@ end vnt2 = VarNamedTuple() expected_merge = VarNamedTuple() # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, 1.0, @varname(a)) vnt2 = setindex!!(vnt2, 2.0, @varname(b)) @@ -238,7 +268,7 @@ end expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) expected_merge = setindex!!(expected_merge, 2, @varname(c)) expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -250,7 +280,7 @@ end expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) @@ -259,13 +289,13 @@ end vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) - @test merge(vnt1, vnt2) == expected_merge + @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) @@ -289,9 +319,9 @@ end expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) expected_merge_12 = setindex!!(expected_merge_12, 2, 
@varname(a[1])) expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) - @test merge(vnt1, vnt2) == expected_merge_12 + @test @inferred(merge(vnt1, vnt2)) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) - @test merge(vnt2, vnt1) == expected_merge_21 + @test @inferred(merge(vnt2, vnt1)) == expected_merge_21 vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -310,11 +340,13 @@ end @testset "keys" begin vnt = VarNamedTuple() - @test keys(vnt) == () + @test @inferred(keys(vnt)) == () @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(a)) - @test keys(vnt) == (@varname(a),) + # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. + # We should improve type stability of keys(). + @test @inferred(keys(vnt)) == (@varname(a),) @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) @@ -394,26 +426,26 @@ end io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == "VarNamedTuple(;)" + @test output == "VarNamedTuple()" vnt = setindex!!(vnt, "s", @varname(a)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == """VarNamedTuple(; a="s")""" + @test output == """VarNamedTuple(a = "s",)""" vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) io = IOBuffer() show(io, vnt) output = String(take!(io)) - @test output == """VarNamedTuple(; a="s", b=[1, 2, 3])""" + @test output == """VarNamedTuple(a = "s", b = [1, 2, 3])""" vnt = setindex!!(vnt, :dada, @varname(c[2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(; a="s", b=[1, 2, 3], c=PartialArray{Symbol,1}((2,) => :dada))""" + VarNamedTuple(a = "s", b = [1, 2, 3], c = PartialArray{Symbol,1}((2,) => :dada))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() @@ -424,14 +456,14 @@ end # brittle tests, we normalise the output: output = replace(output, "DynamicPPL." => "", "VarNamedTuples." => "") @test output == """ - VarNamedTuple(; a="s", b=[1, 2, 3], \ - c=PartialArray{Symbol,1}((2,) => :dada), \ - d=VarNamedTuple(; \ - e=PartialArray{VarNamedTuple{(:f,), \ + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada), \ + d = VarNamedTuple(\ + e = PartialArray{VarNamedTuple{(:f,), \ Tuple{VarNamedTuple{(:g,), \ Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ - VarNamedTuple(; f=VarNamedTuple(; g=PartialArray{Float64,1}((1,) => 16.0, \ - (2,) => 17.0))))))""" + VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0),),)),))""" end end From 59f67fd1c62a5ad8885fa207f666dc58d713117c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:23:45 +0000 Subject: [PATCH 33/45] Add VNT too API docs --- docs/src/api.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index adb476db5..a3d93aa22 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -364,6 +364,12 @@ Base.empty! SimpleVarInfo ``` +#### `VarNamedTuple` + +```@docs +VarNamedTuple +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. 
From 9aba468eeb65471f26469475d168b90cdcf54b61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:27:36 +0000 Subject: [PATCH 34/45] Fix doctests --- src/varnamedtuple.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index e49e1cb66..d3f2ba13a 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -126,8 +126,10 @@ Create a new `PartialArray`. The element type and number of dimensions have to be specified explicitly as type parameters. The positional arguments can be `Pair`s of indices and values. For example, ```jldoctest +julia> using DynamicPPL.VarNamedTuples: PartialArray + julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) -PartialArray{Int,2}((1, 2) => 5, (3, 4) => 10) +PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10) ``` The optional keywoard argument `min_size` can be used to specify the minimum initial size. @@ -580,14 +582,18 @@ end Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. ```jldoctest +julia> using DynamicPPL: VarNamedTuple, setindex!! + +julia> using DynamicPPL.VarNamedTuples: apply!! + julia> vnt = VarNamedTuple() -() +VarNamedTuple() julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a)) -(a -> [1, 2, 3]) +VarNamedTuple(a = [1, 2, 3],) -julia> VarNamedTuples.apply!!(x -> x .+ 1, vnt, @varname(a)) -(a -> [2, 3, 4]) +julia> apply!!(x -> x .+ 1, vnt, @varname(a)) +VarNamedTuple(a = [2, 3, 4],) ``` """ function apply!!(func, vnt::VarNamedTuple, name::VarName) From 0b4c772460f8283b45e919f7e6277db864f2371b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:39:24 +0000 Subject: [PATCH 35/45] Clean up tests a bit --- test/varnamedtuple.jl | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 53ce10e94..67f3d5c2b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -258,7 +258,6 @@ end vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() expected_merge = VarNamedTuple() - # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense. @test @inferred(merge(vnt1, vnt2)) == expected_merge vnt1 = setindex!!(vnt1, 1.0, @varname(a)) @@ -341,29 +340,23 @@ end @testset "keys" begin vnt = VarNamedTuple() @test @inferred(keys(vnt)) == () - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(a)) # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. # We should improve type stability of keys(). 
@test @inferred(keys(vnt)) == (@varname(a),) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) @test keys(vnt) == (@varname(a), @varname(b)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 15, @varname(b[2])) @test keys(vnt) == (@varname(a), @varname(b)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, [10], @varname(c.x.y)) @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, -1.0, @varname(d[4])) @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) @test keys(vnt) == ( @@ -373,7 +366,6 @@ end @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) @test keys(vnt) == ( @@ -387,7 +379,6 @@ end @varname(j[3]), @varname(j[4]), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(j[6])) @test keys(vnt) == ( @@ -402,7 +393,6 @@ end @varname(j[4]), @varname(j[6]), ) - @test all(x -> haskey(vnt, x), keys(vnt)) vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) @test keys(vnt) == ( @@ -418,7 +408,6 @@ end @varname(j[6]), @varname(n[2].a), ) - @test all(x -> haskey(vnt, x), keys(vnt)) end @testset "printing" begin @@ -445,7 +434,8 @@ end show(io, vnt) output = String(take!(io)) @test output == """ - VarNamedTuple(a = "s", b = [1, 2, 3], c = PartialArray{Symbol,1}((2,) => :dada))""" + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada))""" vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) io = IOBuffer() From 38662a8537ac2d00f10f743affd6116c92e06e4a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:46:22 +0000 Subject: [PATCH 36/45] Fix API docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index a3d93aa22..1a22e4cdb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -367,7 +367,7 @@ SimpleVarInfo #### `VarNamedTuple` ```@docs -VarNamedTuple +DynamicPPL.VarNamedTuples.VarNamedTuple ``` ### Accumulators From e41afcaf063124cd61fbefe3a58c5966be8ca61f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Nov 2025 19:47:34 +0000 Subject: [PATCH 37/45] Fix a bug and a docstring --- src/chains.jl | 5 +---- src/contexts/init.jl | 8 -------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index d01606c3d..4d69b3590 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -136,10 +136,7 @@ function ParamsWithStats( include_log_probs::Bool=true, ) where {Tlink} strategy = InitFromParams( - VectorWithRanges{Tlink}( - ldf._iden_varname_ranges, ldf._varname_ranges, param_vector - ), - nothing, + VectorWithRanges{Tlink}(ldf._varname_ranges, param_vector), nothing ) accs = if include_log_probs ( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 305f28767..90394a24c 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -227,14 +227,6 @@ this `VectorWithRanges` are linked/not linked, or `nothing` if either the linkin not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not affect functionality or correctness, but causes more work to be done at runtime, with possible impacts on type stability and performance. 
- -In the simplest case, this could be accomplished only with a single dictionary mapping -VarNames to ranges and link status. However, for performance reasons, we separate out -VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All -non-identity-optic VarNames are stored in the `varname_ranges` Dict. - -It would be nice to improve the NamedTuple and Dict approach. See, e.g. -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} # Ranges for all VarNames From 8c50bbb2bd3e6aaff059ea9a5a3afe6455f5a4b1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Nov 2025 09:14:02 +0000 Subject: [PATCH 38/45] Apply suggestions from code review Co-authored-by: Penelope Yong --- docs/src/internals/varnamedtuple.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 7198aae9f..2d787574b 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -35,7 +35,7 @@ The `identity` lens presents no complications, and in fact in the above example It is the `IndexLenses` that require more structure. An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. -`VarNamedTuple` can not deal with `IndexLens`es in their full generality, for reasons we'll discuss below. +`VarNamedTuple` cannot deal with `IndexLens`es in their full generality, for reasons we'll discuss below. Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`. @@ -73,7 +73,7 @@ julia> print(vnt) VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) ``` -The output there may be a bit hard bit hard to parse, so to illustrate: +The output there may be a bit hard to parse, so to illustrate: ```julia julia> vnt[@varname(b)] @@ -103,7 +103,7 @@ a / | b \ d The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. We do not define a method for `Base.setindex!` at all, `setindex!!` is the only way. -This is because `VarNamedTuple` mixes mutable an immutable data structures. +This is because `VarNamedTuple` mixes mutable and immutable data structures. It is also for user convenience: One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. Rather the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed. @@ -150,7 +150,7 @@ This design has a several of benefits, for performance and generality, but it al 1. The lack of support for `Colon`s in `VarName`s. 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing. - 3. `VarNamedTuple` can not store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` can not be stored in the same `VarNamedTuple`. - 4. 
There is an assymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. + 3. `VarNamedTuple` cannot store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` cannot be stored in the same `VarNamedTuple`. + 4. There is an asymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`. The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`. From cae8864c636bb33d49d26a8aaaf19ec173935689 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Nov 2025 09:16:25 +0000 Subject: [PATCH 39/45] Fix VNT docs --- docs/src/internals/varnamedtuple.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 2d787574b..47ff9c65e 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -94,9 +94,9 @@ Or as a tree drawing, where `PA` marks a `PartialArray`: ``` /----VNT------\ a / | b \ d - 1 [2.0, 3.0] VNT - | e - PA(2 => VNT) + 1 VNT VNT + | c | e + [2.0, 3.0] PA(2 => VNT) | f PA(3 => :hip, 4 => :hop) ``` From c27f5e0854b11e21c960569dcf478df2dba066a6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 13:22:04 +0000 Subject: [PATCH 40/45] Make threadsafe evaluation opt-in (#1151) * Make threadsafe evaluation opt-in * Reduce number of type parameters in methods * Make `warned_warn_about_threads_threads_threads_threads` shorter * Improve `setthreadsafe` docstring * warn on bare `@threads` as well * fix merge * Fix performance issues * Use maxthreadid() in TSVI * Move convert_eltype code to threadsafe eval function * Point to new Turing docs page * Add a test for setthreadsafe * Tidy up check_model * Apply suggestions from code review Fix outdated docstrings Co-authored-by: Markus Hauru * Improve warning message * Export `requires_threadsafe` * Add an actual docstring for `requires_threadsafe` --------- Co-authored-by: Markus Hauru --- HISTORY.md | 39 +++++++- docs/src/api.md | 8 ++ src/DynamicPPL.jl | 2 + src/compiler.jl | 54 ++++++++--- src/debug_utils.jl | 6 +- src/model.jl | 185 ++++++++++++++++++------------------- src/simple_varinfo.jl | 10 +- src/threadsafe.jl | 7 +- src/varinfo.jl | 16 +--- test/compiler.jl | 42 +++++++-- test/logdensityfunction.jl | 78 ++++++++-------- test/threadsafe.jl | 75 +++++---------- 12 files changed, 281 insertions(+), 241 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ff28349d8..5dcb008d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -9,12 +9,49 @@ This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. -For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. 
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments. As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. If you were previously relying on this behaviour, you will need to store a VarInfo separately. +#### Threadsafe evaluation + +DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel. +Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread. + +In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**. +If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as so, using: + +```julia +@model f() = ... +model = f() +model = setthreadsafe(model, true) +``` + +The problem with the previous on-by-default is that it can sacrifice a huge amount of performance when thread safety is not needed. +This is especially true when running Julia in a notebook, where multiple threads are often enabled by default. +Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation. + +**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.** +This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros: + + - tilde-statements + - calls to `@addlogprob!` + - any direct manipulation of the special `__varinfo__` variable + +If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe. +**Notably, the following do not require threadsafe evaluation:** + + - Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation. + - Sampling with `AbstractMCMC.MCMCThreads()`. + +For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/). + +When threadsafe evaluation is enabled for a model, an internal flag is set on the model. +The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean. +This function is newly exported in this version of DynamicPPL. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. diff --git a/docs/src/api.md b/docs/src/api.md index adb476db5..193a6ce4c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref): contextualize ``` +Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary). +If this is the case, one must enable threadsafe evaluation for a model: + +```@docs +setthreadsafe +requires_threadsafe +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). 
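To make the distinction above concrete, here is a minimal sketch (an editor's illustration, not part of the patch; the model names are made up) of a model that does need the flag and one that does not:

```julia
using DynamicPPL, Distributions

# Tilde-statements inside the threaded loop mutate the internal VarInfo in parallel,
# so this model must be marked as threadsafe before evaluation.
@model function needs_threadsafe(y)
    x ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(x)
    end
end
model = setthreadsafe(needs_threadsafe(zeros(100)), true)
requires_threadsafe(model)  # true

# Here only the per-observation log-likelihood terms are computed in parallel;
# the single `@addlogprob!` call happens outside the threaded block, so the
# default (non-threadsafe) evaluation is fine. (The `@model` macro may still
# print its advisory warning, since its check is purely syntactic.)
@model function no_threadsafe_needed(y)
    x ~ Normal()
    lls = Vector{typeof(x)}(undef, length(y))
    Threads.@threads for i in eachindex(y)
        lls[i] = logpdf(Normal(x), y[i])
    end
    @addlogprob! sum(lls)
end
```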
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a885f6a96..fda428eaa 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -90,6 +90,8 @@ export AbstractVarInfo, Model, getmissings, getargnames, + setthreadsafe, + requires_threadsafe, extract_priors, values_as_in_model, # evaluation diff --git a/src/compiler.jl b/src/compiler.jl index 3324780ca..1b4260121 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn) modeldef = build_model_definition(expr) # Generate main body - modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn) + modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true) return build_output(modeldef, linenumbernode) end @@ -346,10 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +generate_mainbody(mod, expr, warn, warn_threads) = + generate_mainbody!(mod, Symbol[], expr, warn, warn_threads) -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found, x, warn, warn_threads) = x +function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$sym`" push!(found, sym) @@ -357,17 +358,39 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Flag to determine whether we've issued a warning for threadsafe macros Note that this + # detection is not fully correct. We can only detect the presence of a macro that has + # the symbol `Threads.@threads`, however, we can't detect if that *is actually* + # Threads.@threads from Base.Threads. + # Do we don't want escaped expressions because we unfortunately # escape the entire body afterwards. - Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + Meta.isexpr(expr, :escape) && + return generate_mainbody(mod, found, expr.args[1], warn, warn_threads) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + if ( + expr.args[1] == Symbol("@threads") || + expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && + warn_threads + ) + warn_threads = false + @warn ( + "It looks like you are using `Threads.@threads` in your model definition." * + "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." * + " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." * + "\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." * + "\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required." + ) + end + return generate_mainbody!( + mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads + ) end # Modify dotted tilde operators. 
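For context on the two equality checks in the macro-detection branch above, this is how the two spellings of the macro parse (an illustrative REPL sketch; `ex` is just a local name, not part of the patch):

```julia
julia> ex = Meta.parse("Threads.@threads for i in 1:10 end");

julia> ex.args[1]                     # qualified form: a dotted macro reference
:(Threads.var"@threads")

julia> ex.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads")))
true

julia> Meta.parse("@threads for i in 1:10 end").args[1]   # bare form: just the macro's Symbol
Symbol("@threads")
```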
@@ -375,7 +398,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde return generate_mainbody!( - mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads ) end @@ -385,8 +408,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end @@ -397,13 +420,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_assign return Base.remove_linenums!( generate_assign( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)..., + ) end function generate_assign(left, right) @@ -699,7 +725,7 @@ function build_output(modeldef, linenumbernode) # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) - return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...)) + return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...)) end return MacroTools.@q begin diff --git a/src/debug_utils.jl b/src/debug_utils.jl index e8b50a0b7..8810b9819 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -424,8 +424,10 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a + # check on the merged accumulator, rather than checking it in the accumulate_assume + # calls. That way we can also correctly support multi-threaded evaluation. + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/model.jl b/src/model.jl index 7d5bbf2fb..e82fdc60c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However, non-traditional use-cases `missings` can be defined differently. All variables in `missings` are treated as random variables rather than observations. +The `Threaded` type parameter indicates whether the model requires threadsafe evaluation +(i.e., whether the model contains statements which modify the internal VarInfo that are +executed in parallel). By default, this is set to `false`. + The default arguments are used internally when constructing instances of the same model with different arguments. 
@@ -33,26 +37,27 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx @doc """ - Model{missings}(f, args::NamedTuple, defaults::NamedTuple) + Model{Threaded,missings}(f, args::NamedTuple, defaults::NamedTuple) Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ - function Model{missings}( + function Model{Threaded,missings}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}( f, args, defaults, context ) end @@ -66,23 +71,39 @@ Create a model with evaluation function `f` and missing arguments deduced from ` Default arguments `defaults` are used internally when constructing instances of the same model with different arguments. """ -@generated function Model( +@generated function Model{Threaded}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), -) where {F,argnames,Targs,kwargnames,Tkwargs} +) where {Threaded,F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{Threaded,$(missing_args..., missing_kwargs...)}( + f, args, defaults, context + )) +end + +function Model{Threaded}( + f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs... +) where {Threaded} + return Model{Threaded}(f, args, NamedTuple(kwargs), context) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +""" + requires_threadsafe(model::Model) + +Return whether `model` has been marked as needing threadsafe evaluation (using +`setthreadsafe`). +""" +function requires_threadsafe( + ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return Threaded end """ @@ -92,7 +113,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model{requires_threadsafe(model)}(model.f, model.args, model.defaults, context) end """ @@ -105,6 +126,33 @@ function setleafcontext(model::Model, context::AbstractContext) return contextualize(model, setleafcontext(model.context, context)) end +""" + setthreadsafe(model::Model, threadsafe::Bool) + +Returns a new `Model` with its threadsafe flag set to `threadsafe`. 
+ +Threadsafe evaluation ensures correctness when executing model statements that mutate the +internal `VarInfo` object in parallel. For example, this is needed if tilde-statements are +nested inside `Threads.@threads` or similar constructs. + +It is not needed for generic multithreaded operations that don't involve VarInfo. For +example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` +outside of the parallel region is safe without needing to set `threadsafe=true`. + +It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. + +Setting `threadsafe` to `true` increases the overhead in evaluating the model. Please see +[the Turing.jl docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more +details. +""" +function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} + return if requires_threadsafe(model) == threadsafe + model + else + Model{threadsafe,M}(model.f, model.args, model.defaults, model.context) + end +end + """ model | (x = 1.0, ...) @@ -863,16 +911,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -889,10 +927,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -@inline function init!!( - # Note that this `@inline` is mandatory for performance, especially for - # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for - # trivial models) and much slower runtime. +function init!!( rng::Random.AbstractRNG, model::Model, vi::AbstractVarInfo, @@ -900,36 +935,11 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. ) ctx = InitContext(rng, strategy) model = DynamicPPL.setleafcontext(model, ctx) - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - return if Threads.nthreads() > 1 - # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that - # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` - # won't have been filled with parameters prior to `init!!` being called. - # - # Note that this eltype promotion is only needed for threadsafe evaluation. In an - # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a - # similar method. In other words, it should not be here, and it should not be inside - # `unflatten` either. The problem is performance. Shifting this code around can have - # massive, inexplicable, impacts on performance. This should be investigated - # properly. 
- param_eltype = DynamicPPL.get_param_eltype(strategy) - accs = map(vi.accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) - end - vi = DynamicPPL.setaccs!!(vi, accs) - tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) - retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) - return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) - else - return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) - end + return DynamicPPL.evaluate!!(model, vi) end -@inline function init!!( +function init!!( model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - # This `@inline` is also mandatory for performance return init!!(Random.default_rng(), model, vi, strategy) end @@ -938,55 +948,42 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. +If the model has been marked as requiring threadsafe evaluation, are available, the varinfo +provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) + return if requires_threadsafe(model) + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. + # The below line is finicky for type stability. For instance, assigning the eltype to + # convert to into an intermediate variable makes this unstable (constant propagation + # fails). Take care when editing. + param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context) + accs = map(DynamicPPL.getaccs(varinfo)) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + varinfo = DynamicPPL.setaccs!!(varinfo, accs) + wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) else - evaluate_threadunsafe!!(model, varinfo) + _evaluate!!(model, resetaccs!!(varinfo)) end end -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. 
- -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) - return _evaluate!!(model, resetaccs!!(varinfo)) -end - -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 434480be6..9d3fb1925 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -278,15 +278,7 @@ end function unflatten(svi::SimpleVarInfo, x::AbstractVector) vals = unflatten(svi.values, x) - # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is - # required but undesireable. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) - ) - return SimpleVarInfo(vals, accs, svi.transformation) + return SimpleVarInfo(vals, svi.accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 89877f385..0e906b6ca 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo) # fields. This is not good practice --- see # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. - # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()] return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..14e08515c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -367,21 +367,7 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is - # a gradient type of some AD backend. - # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! - # for ThreadSafeVarInfo. 
In that one, if the map produces e.g a ForwardDiff.Dual, but - # the accumulators in the VarInfo are plain floats, we error since we can't change the - # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here - # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just - # plain ugly and hacky. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) - ) - return VarInfo(md, accs) + return VarInfo(md, vi.accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..9056f666a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. @model function demo() @@ -793,4 +788,39 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "Threads.@threads detection" begin + # Check that the compiler detects when `Threads.@threads` is used inside a model + + e1 = quote + @model function f1() + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e1) + + e2 = quote + @model function f2() + for j in 1:10 + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e2) + + e3 = quote + @model function f3() + begin + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e3) + end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f43ed45a4..1d609a013 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -51,21 +51,19 @@ using Mooncake: Mooncake end @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.LogDensityFunction(model) - - xs = [1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end + N = 100 + model = setthreadsafe(threaded(zeros(N)), true) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end end @@ -125,34 +123,32 @@ end end @testset "LogDensityFunction: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(ldf, x)) - @test iszero(bench.allocs) - end + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity($ldf, $x)) + @test iszero(bench.allocs) end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..879e936d6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -5,13 +5,23 @@ @test threadsafe_vi.varinfo === vi @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 + @test length(threadsafe_vi.accs_by_thread) == Threads.maxthreadid() expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end + @testset "setthreadsafe" begin + @model f() = x ~ Normal() + model = f() + @test !DynamicPPL.requires_threadsafe(model) + model = setthreadsafe(model, true) + @test DynamicPPL.requires_threadsafe(model) + model = setthreadsafe(model, false) + @test !DynamicPPL.requires_threadsafe(model) + end + # TODO: Add more tests of the public API @testset "API" begin vi = VarInfo(gdemo_default) @@ -41,8 +51,6 @@ end @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - x = rand(10_000) @model function wthreads(x) @@ -52,63 +60,24 @@ x[i] ~ Normal(x[i - 1], 1) end end - model = wthreads(x) - - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- DynamicPPL.evaluate_threadsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo + model = setthreadsafe(wthreads(x), true) - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) + function correct_lp(x) + lp = logpdf(Normal(0, 1), x[1]) for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) + lp += logpdf(Normal(x[i - 1], 1), x[i]) end + return lp end - model = wothreads(x) vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end + _, vi = DynamicPPL.evaluate!!(model, vi) - println("Without `@threads`:") - println(" default:") - @time model(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo + # check that logp is correct + @test getlogjoint(vi) ≈ correct_lp(x) + # check that varinfo was wrapped during the model evaluation + @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From 54ae7e30df1cfaa1f69202c8637d58afce55134d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 1 Dec 2025 18:58:53 +0000 Subject: [PATCH 41/45] Standardise `:lp` -> `:logjoint` (#1161) * Standardise `:lp` -> `:logjoint` * changelog * fix a test --- HISTORY.md | 4 ++++ src/chains.jl | 4 ++-- test/chains.jl | 8 ++++---- test/ext/DynamicPPLMCMCChainsExt.jl | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index c15b4136a..48f2efb0e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -65,6 +65,10 @@ Leaf contexts require no changes, apart from a removal of the `NodeTrait` functi `ConditionContext` and `PrefixContext` are no longer exported. You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. +#### ParamsWithStats + +In the 'stats' part of `DynamicPPL.ParamsWithStats`, the log-joint is now consistently represented with the key `logjoint` instead of `lp`. + #### Miscellaneous Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. 
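For downstream code the practical effect of the renaming is small; roughly (a sketch, where `chn` stands for any `MCMCChains.Chains` assembled from `ParamsWithStats`, as exercised in the tests below):

```julia
chn[:logjoint]                                          # was chn[:lp] before 0.39
chn[:logjoint] ≈ chn[:logprior] .+ chn[:loglikelihood]  # still the sum of the two components
```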
diff --git a/src/chains.jl b/src/chains.jl index d01606c3d..8ce4979c6 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -57,7 +57,7 @@ function ParamsWithStats( ( logprior=DynamicPPL.getlogprior(varinfo), loglikelihood=DynamicPPL.getloglikelihood(varinfo), - lp=DynamicPPL.getlogjoint(varinfo), + logjoint=DynamicPPL.getlogjoint(varinfo), ), ) end @@ -158,7 +158,7 @@ function ParamsWithStats( ( logprior=DynamicPPL.getlogprior(vi), loglikelihood=DynamicPPL.getloglikelihood(vi), - lp=DynamicPPL.getlogjoint(vi), + logjoint=DynamicPPL.getlogjoint(vi), ), ) end diff --git a/test/chains.jl b/test/chains.jl index 12a9ece71..498e2e912 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -20,9 +20,9 @@ using Test @test length(ps.params) == 2 @test haskey(ps.stats, :logprior) @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) + @test haskey(ps.stats, :logjoint) @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood @test ps.params[@varname(y)] ≈ ps.params[@varname(x)] + 1 @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(y)]), z) @@ -34,9 +34,9 @@ using Test @test length(ps.params) == 1 @test haskey(ps.stats, :logprior) @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) + @test haskey(ps.stats, :logjoint) @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.stats.logjoint ≈ ps.stats.logprior + ps.stats.loglikelihood @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(x)] + 1), z) end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 6091492df..445270ef8 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -20,7 +20,7 @@ using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC @test size(c, 1) == 50 @test size(c, 3) == 3 @test Set(c.name_map.parameters) == Set([:x, :y]) - @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp]) + @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :logjoint]) @test logpdf.(Normal(), c[:x]) ≈ c[:logprior] @test c.info.varname_to_symbol[@varname(x)] == :x @test c.info.varname_to_symbol[@varname(y)] == :y From 384e3ac398784af123f6e3596d83aab61a754a96 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 17:36:55 +0000 Subject: [PATCH 42/45] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/varnamedtuple.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d3f2ba13a..9efc1efa2 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -132,7 +132,7 @@ julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10) ``` -The optional keywoard argument `min_size` can be used to specify the minimum initial size. +The optional keyword argument `min_size` can be used to specify the minimum initial size. This is purely a performance optimisation, to avoid resizing if the eventual size is known ahead of time. """ @@ -521,9 +521,9 @@ end """ varname_to_lens(name::VarName{S}) where {S} -Convert a `VarName` to an `Accessor` lens, wrapping the first symdol in a `PropertyLens`. +Convert a `VarName` to an `Accessor` lens, wrapping the first symbol in a `PropertyLens`. 
-This is used to simplify method dispatch for `_getindx`, `_setindex!!`, and `_haskey`, by +This is used to simplify method dispatch for `_getindex`, `_setindex!!`, and `_haskey`, by considering `VarName`s to just be a special case of lenses. """ function _varname_to_lens(name::VarName{S}) where {S} From 9d61a54f5dc86ae94b597e068fcf60b0910c0194 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 17:41:46 +0000 Subject: [PATCH 43/45] Add a microoptimisation --- src/varnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 9efc1efa2..f9040e6ad 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -304,8 +304,8 @@ function _resize_partialarray!!(pa::PartialArray, inds) # may use a linear index that does not match between the old and the new arrays. @inbounds for i in CartesianIndices(pa.data) mask_val = pa.mask[i] - new_mask[i] = mask_val if mask_val + new_mask[i] = mask_val new_data[i] = pa.data[i] end end From 8c8e39f98519b7765c1f1c1b3de9f11e675ea62b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 17:49:31 +0000 Subject: [PATCH 44/45] Improve docstrings --- src/varnamedtuple.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index f9040e6ad..881fde767 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -474,14 +474,16 @@ is `merge`, which recursively merges two `VarNamedTuple`s. The there are two major limitations to indexing by VarNamedTuples: -* `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined. -* Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`. +* `VarName`s with `Colon`s, (e.g. `a[:]`) are not supported. This is because the meaning of + `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined. +* Any `VarNames` with IndexLenses` must have a consistent number of indices. That is, one + cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`. `setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store -heterogeneous data under different indices of the same symbol. That is, if one either +heterogeneous data under different indices of the same symbol. That is, if either -* sets `a[1]` and `a[2]` to be of different types, or -* sets `a[1].b` and `a[2].c`, without setting `a[1].c`. or `a[2].b`, +* one sets `a[1]` and `a[2]` to be of different types, or +* if `a[1]` and `a[2]` both exist, one sets `a[1].b` without setting `a[2].b`, then getting values for `a[1]` or `a[2]` will not be type stable. 
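To illustrate the stability condition spelled out in that docstring, a rough sketch (this assumes the empty `VarNamedTuple()` constructor and the `DynamicPPL.VarNamedTuples` module path referenced in the API docs; it is not part of the patch):

```julia
using BangBang: setindex!!
using DynamicPPL: @varname
using DynamicPPL.VarNamedTuples: VarNamedTuple   # module path as listed in the API docs

vnt = VarNamedTuple()                        # assumed empty constructor
vnt = setindex!!(vnt, 1.0, @varname(a[1]))
vnt = setindex!!(vnt, 2.0, @varname(a[2]))   # same element type: lookups stay type stable
vnt[@varname(a[2])]                          # inferred as Float64

vnt = setindex!!(vnt, :sym, @varname(a[3]))  # heterogeneous types are accepted (storage widens),
vnt[@varname(a[1])]                          # ...but lookups like this are no longer type stable
```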
From c818bf887ab474d15944394e62237f8be29b0f32 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 3 Dec 2025 18:01:15 +0000 Subject: [PATCH 45/45] Simplify use of QuoteNodes --- src/varnamedtuple.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 881fde767..1ca75d343 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -564,11 +564,11 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) push!(exs, :(data = (;))) for name in all_names val_expr = if name in names1 && name in names2 - :(_merge_recursive(vnt1.data[$(QuoteNode(name))], vnt2.data[$(QuoteNode(name))])) + :(_merge_recursive(vnt1.data.$name, vnt2.data.$name)) elseif name in names1 - :(vnt1.data[$(QuoteNode(name))]) + :(vnt1.data.$name) else - :(vnt2.data[$(QuoteNode(name))]) + :(vnt2.data.$name) end push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,))))) end
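A closing note on the last hunk: for `NamedTuple` fields the property-access form and the `QuoteNode`-indexing form retrieve the same field, which is what makes the simplification safe. An illustrative REPL sketch (editor's addition, not part of the patch):

```julia
julia> name = :a;

julia> :(vnt1.data.$name)                     # the simplified form
:(vnt1.data.a)

julia> :(vnt1.data[$(QuoteNode(name))])       # the original form
:(vnt1.data[:a])

julia> data = (a = 1, b = 2); data.a == data[:a]   # both access the same NamedTuple field
true
```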