diff --git a/docs/src/api.md b/docs/src/api.md index 193a6ce4c..acbfaa9a3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -372,6 +372,12 @@ Base.empty! SimpleVarInfo ``` +#### `VarNamedTuple` + +```@docs +DynamicPPL.VarNamedTuples.VarNamedTuple +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md new file mode 100644 index 000000000..47ff9c65e --- /dev/null +++ b/docs/src/internals/varnamedtuple.md @@ -0,0 +1,156 @@ +# `VarNamedTuple` + +In DynamicPPL there is often a need to store data keyed by `VarName`s. +This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. +Historically, we've had several different approaches to this: Dictionaries, `NamedTuple`s, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. + +`VarNamedTuple`, aka VNT, exists to unify the treatment of these use cases and to handle them all in a robust and performant way. +It is a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type-stable and performant manner. + +`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`s. +Let's first talk about the `NamedTuple` part. +This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`. +In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lenses as keys.
+For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as + +``` +VarNamedTuple(; x=1, y=VarNamedTuple(; z=2)) +``` + +where `VarNamedTuple(; x=a, y=b)` is just a thin wrapper around the `NamedTuple` `(; x=a, y=b)`. + +It's often handy to think of this as a tree, with each node being a `VarNamedTuple`, like so: + +``` + VNT +x / \ y + 1 VNT + \ z + 2 +``` + +If all `VarName`s consisted of only `PropertyLens`es, we would be done designing the data structure. +However, recall that `VarName`s allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). +The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`. +It is the `IndexLens`es that require more structure. + +An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`. +`VarNamedTuple` cannot deal with `IndexLens`es in their full generality, for reasons we'll discuss below. +Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof. + +When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`. +When we meet an `IndexLens`, we instead insert into the tree something called a `PartialArray`. + +A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset. +It is a data structure we define ourselves for use within `VarNamedTuple`s. +A `PartialArray` has an element type and a number of dimensions, and they are known at compile time, but it does not have a size, and thus is not an `AbstractArray`.
+This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions. +The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are. + +This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`. +A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be. +If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question. + +Compared to the full indexing syntax of Julia, `PartialArray`s have other restrictions as well: +They do not support linear indexing into multidimensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays (as in `rand(4)[[true, false, true, false]]`). +This is mostly because we haven't seen a need to support them, and implementing them would complicate the codebase for little gain. +We may add support for them later if needed. + +`PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s. +Thus we nest them with `VarNamedTuple`s to support storing `VarName`s with arbitrary combinations of `PropertyLens`es and `IndexLens`es.
+A code example best illustrates this: + +```julia +julia> vnt = VarNamedTuple(); + +julia> vnt = setindex!!(vnt, 1.0, @varname(a)); + +julia> vnt = setindex!!(vnt, [2.0, 3.0], @varname(b.c)); + +julia> vnt = setindex!!(vnt, [:hip, :hop], @varname(d.e[2].f[3:4])); + +julia> print(vnt) +VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) +``` + +The output there may be a bit hard to parse, so to illustrate: + +```julia +julia> vnt[@varname(b)] +VarNamedTuple(; c=[2.0, 3.0]) + +julia> vnt[@varname(b.c[1])] +2.0 + +julia> vnt[@varname(d.e)] +PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))) + +julia> vnt[@varname(d.e[2].f)] +PartialArray{Symbol,1}((3,) => hip, (4,) => hop) +``` + +Or as a tree drawing, where `PA` marks a `PartialArray`: + +``` + /----VNT------\ +a / | b \ d + 1 VNT VNT + | c | e + [2.0, 3.0] PA(2 => VNT) + | f + PA(3 => :hip, 4 => :hop) +``` + +The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. +We do not define a method for `Base.setindex!` at all; `setindex!!` is the only way. +This is because `VarNamedTuple` mixes mutable and immutable data structures. +It is also for user convenience: +One never has to think about whether the value one is inserting into a `VarNamedTuple` is of the right type to fit in. +Rather, the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed. +`VarNamedTuple`, or more precisely `PartialArray`, even explicitly concretises element types whenever possible.
+For instance, one can make an abstractly typed `VarNamedTuple` like so: + +```julia +julia> vnt = VarNamedTuple(); + +julia> vnt = setindex!!(vnt, 1.0, @varname(a[1])); + +julia> vnt = setindex!!(vnt, "hello", @varname(a[2])); + +julia> print(vnt) +VarNamedTuple(; a=PartialArray{Any,1}((1,) => 1.0, (2,) => hello)) +``` + +Note the element type of `PartialArray{Any}`. +But if one changes the values to make them homogeneous, the element type is automatically made concrete again: + +```julia +julia> vnt = setindex!!(vnt, "me here", @varname(a[1])); + +julia> print(vnt) +VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello)) +``` + +This approach is at the core of why `VarNamedTuple` is performant: +As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type stable. +Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for `getindex` and `setindex!!`. +However, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`; +you can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` with stability. + +Note that if you `setindex!!` a new value into a `VarNamedTuple` with an `IndexLens`, this causes a `PartialArray` to be created. +However, if there already is a regular `Base.Array` stored in a `VarNamedTuple`, you can index into it with `IndexLens`es without involving `PartialArray`s. +That is, if you do `vnt = setindex!!(vnt, [1.0, 2.0], @varname(a))`, you can then get the values with e.g. `vnt[@varname(a[1])]`, which returns `1.0`.
+You can also set the elements with `vnt = setindex!!(vnt, 3.0, @varname(a[1]))`, and this will modify the existing `Base.Array`. +At this point you cannot set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, 5.0, @varname(a[5]))`. +The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is. + +## Limitations + +This design has several benefits for performance and generality, but it also has limitations: + + 1. The lack of support for `Colon`s in `VarName`s. + 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing. + 3. `VarNamedTuple` cannot store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` cannot be stored in the same `VarNamedTuple`. + 4. There is an asymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. + The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`. + The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`.
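A short REPL sketch of the asymmetry in point 4, assuming `VarNamedTuple`, `setindex!!`, and `@varname` are in scope as in the earlier examples (the displayed outputs are indicative, not verbatim):

```julia
julia> vnt = VarNamedTuple();

julia> vnt = setindex!!(vnt, [1.0, 2.0], @varname(a));  # attaches a whole Base.Array to `a`

julia> vnt = setindex!!(vnt, 1.0, @varname(b[1]));  # element-wise: creates a PartialArray

julia> vnt = setindex!!(vnt, 2.0, @varname(b[2]));

julia> vnt[@varname(a)]  # the whole array can be retrieved
2-element Vector{Float64}:
 1.0
 2.0

julia> vnt[@varname(b[2])]  # individual elements can be retrieved
2.0
```

Even though both `b[1]` and `b[2]` have been set, `vnt[@varname(b)]` yields the underlying `PartialArray` rather than a `Vector{Float64}`.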
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fda428eaa..25ca59018 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -184,6 +184,8 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuple include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/chains.jl b/src/chains.jl index 8ce4979c6..319579a9c 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -136,10 +136,7 @@ function ParamsWithStats( include_log_probs::Bool=true, ) where {Tlink} strategy = InitFromParams( - VectorWithRanges{Tlink}( - ldf._iden_varname_ranges, ldf._varname_ranges, param_vector - ), - nothing, + VectorWithRanges{Tlink}(ldf._varname_ranges, param_vector), nothing ) accs = if include_log_probs ( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 80a494c23..90394a24c 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -215,8 +215,7 @@ end """ VectorWithRanges{Tlink}( - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, vect::AbstractVector{<:Real}, ) @@ -228,26 +227,14 @@ this `VectorWithRanges` are linked/not linked, or `nothing` if either the linkin not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not affect functionality or correctness, but causes more work to be done at runtime, with possible impacts on type stability and performance. - -In the simplest case, this could be accomplished only with a single dictionary mapping -VarNames to ranges and link status. However, for performance reasons, we separate out -VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All -non-identity-optic VarNames are stored in the `varname_ranges` Dict. - -It would be nice to improve the NamedTuple and Dict approach. See, e.g. -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
""" -struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} - # This NamedTuple stores the ranges for identity VarNames - iden_varname_ranges::N - # This Dict stores the ranges for all other VarNames - varname_ranges::Dict{VarName,RangeAndLinked} +struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} + # Ranges for all VarNames + varname_ranges::VNT # The full parameter vector which we index into to get variable values vect::T - function VectorWithRanges{Tlink}( - iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T - ) where {Tlink,N,T} + function VectorWithRanges{Tlink}(varname_ranges::VNT, vect::T) where {Tlink,VNT,T} if !(Tlink isa Union{Bool,Nothing}) throw( ArgumentError( @@ -255,15 +242,10 @@ struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} ), ) end - return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + return new{Tlink,VNT,T}(varname_ranges, vect) end end -function _get_range_and_linked( - vr::VectorWithRanges, ::VarName{sym,typeof(identity)} -) where {sym} - return vr.iden_varname_ranges[sym] -end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) return vr.varname_ranges[vn] end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 3008a329b..47b49a277 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -90,10 +90,11 @@ Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL m given set of parameters: 1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + inside a VarInfo's metadata, then reads parameter values from the VarInfo during + evaluation. -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. +2. 
With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and + stores them inside a VarInfo's metadata. In general, both of these approaches work fine, but the fact that they modify the VarInfo's metadata can often be quite wasteful. In particular, it is very common that the only outputs @@ -123,14 +124,9 @@ the VarInfo_ a single time when constructing a `LogDensityFunction` object. Insi LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with link status. -For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all -other VarNames, this is stored in a Dict. The internal data structure used to represent this -could almost certainly be optimised further. See e.g. the discussion in -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. - -When evaluating the model, this allows us to combine the parameter vector together with those -ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read -parameter values from the vector. +When evaluating the model, this allows us to combine the parameter vector together with +those ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly +read parameter values from the vector. Note that this assumes that the ranges and link status are static throughout the lifetime of the `LogDensityFunction` object. 
Therefore, a `LogDensityFunction` object cannot handle @@ -146,7 +142,7 @@ struct LogDensityFunction{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, - N<:NamedTuple, + VNT<:VarNamedTuple, ADP<:Union{Nothing,DI.GradientPrep}, # type of the vector passed to logdensity functions X<:AbstractVector, @@ -154,8 +150,7 @@ struct LogDensityFunction{ model::M adtype::AD _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} + _varname_ranges::VNT _adprep::ADP _dim::Int @@ -167,14 +162,11 @@ struct LogDensityFunction{ ) # Figure out which variable corresponds to which index, and # which variables are linked. - all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + all_ranges = get_ranges_and_linked(varinfo) # Figure out if all variables are linked, unlinked, or mixed link_statuses = Bool[] - for ral in all_iden_ranges - push!(link_statuses, ral.is_linked) - end - for (_, ral) in all_ranges - push!(link_statuses, ral.is_linked) + for vn in keys(all_ranges) + push!(link_statuses, all_ranges[vn].is_linked) end Tlink = if all(link_statuses) true @@ -192,9 +184,7 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), - adtype, - x, + LogDensityAt{Tlink}(model, getlogdensity, all_ranges), adtype, x ) end return new{ @@ -202,11 +192,11 @@ struct LogDensityFunction{ typeof(model), typeof(adtype), typeof(getlogdensity), - typeof(all_iden_ranges), + typeof(all_ranges), typeof(prep), typeof(x), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim + model, adtype, getlogdensity, all_ranges, prep, dim ) end end @@ -235,25 +225,19 @@ end ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct 
LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,VNT<:VarNamedTuple} model::M getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} + varname_ranges::VNT function LogDensityAt{Tlink}( - model::M, - getlogdensity::F, - iden_varname_ranges::N, - varname_ranges::Dict{VarName,RangeAndLinked}, + model::M, getlogdensity::F, varname_ranges::N ) where {Tlink,M,F,N} - return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + return new{Tlink,M,F,N}(model, getlogdensity, varname_ranges) end end function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} - strategy = InitFromParams( - VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing - ) + strategy = InitFromParams(VectorWithRanges{Tlink}(f.varname_ranges, params), nothing) accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) return f.getlogdensity(vi) @@ -262,11 +246,7 @@ end function LogDensityProblems.logdensity( ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} ) where {Tlink} - return LogDensityAt{Tlink}( - ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges - )( - params - ) + return LogDensityAt{Tlink}(ldf.model, ldf._getlogdensity, ldf._varname_ranges)(params) end function LogDensityProblems.logdensity_and_gradient( @@ -274,9 +254,7 @@ function LogDensityProblems.logdensity_and_gradient( ) where {Tlink} params = convert(_get_input_vector_type(ldf), params) return DI.value_and_gradient( - LogDensityAt{Tlink}( - ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges - ), + LogDensityAt{Tlink}(ldf.model, ldf._getlogdensity, ldf._varname_ranges), ldf._adprep, ldf.adtype, params, @@ -329,62 +307,42 @@ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtyp Given a `VarInfo`, extract the ranges of each variable in the vectorised 
parameter representation, along with whether each variable is linked or unlinked. -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +This function returns a VarNamedTuple mapping all VarNames to their corresponding +`RangeAndLinked`. """ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) + this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_ranges = merge(all_ranges, this_md_others) end - return all_iden_ranges, all_ranges + return all_ranges end function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others + all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_ranges end function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset 
end function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() + all_ranges = VarNamedTuple() offset = start_offset for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end + all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) offset += length(range) end - return all_iden_ranges, all_ranges, offset + return all_ranges, offset end diff --git a/src/utils.jl b/src/utils.jl index 75fb805dc..fe2879182 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -949,3 +949,19 @@ end Return `typeof(x)` stripped of its type parameters. """ basetypeof(x::T) where {T} = Base.typename(T).wrapper + +# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if +# ReshapeTransforms are composed with each other or with an UnwrapSingeltonTransform, only +# the latter one would be kept. +""" + _compose_no_identity(f, g) + +Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. + +This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type +conflicts. +""" +_compose_no_identity(f, g) = f ∘ g +_compose_no_identity(::typeof(identity), g) = g +_compose_no_identity(f, ::typeof(identity)) = f +_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..1ca75d343 --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,694 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL.
+module VarNamedTuples + +using AbstractPPL +using BangBang +using Accessors +using ..DynamicPPL: _compose_no_identity + +export VarNamedTuple + +# We define our own getindex, setindex!!, and haskey functions, which we use to +# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be +# able to override their behaviour for some types exported from elsewhere without type +# piracy. This is needed because +# 1. We would want to index into things with lenses (from Accessors.jl) using getindex and +# setindex!!, but Accessors does not define these methods. +# 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays. +function _getindex end +function _haskey end +function _setindex!! end + +_getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) +_haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) +_haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) +function _setindex!!(arr::AbstractArray, value, optic::IndexLens) + return setindex!!(arr, value, optic.indices...) +end + +# Some utilities for checking what sort of indices we are dealing with. +_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) +function _is_multiindex(::T) where {T<:Tuple} + return any(x <: UnitRange || x <: Colon for x in T.parameters) +end + +""" + _merge_recursive(x1, x2) + +Recursively merge two values `x1` and `x2`. + +Unlike `Base.merge`, this function is defined for all types, and by default returns the +second argument. It is overridden for `PartialArray` and `VarNamedTuple`, since they are +nested containers, and calls itself recursively on all elements that are found in both +`x1` and `x2`. 
+ +In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]` +is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`, +unless no specific method is defined for the type of `x` and `y`, in which case +`_merge_recursive(x, y) === y`. +""" +_merge_recursive(_, x2) = x2 + +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + +"""A convenience for defining method argument type bounds.""" +const INDEX_TYPES = Union{Integer,UnitRange,Colon} + +""" + PartialArray{ElType,num_dims} + +An array-like structure that may only have some of its elements defined. + +A `PartialArray` is like a `Base.Array`, except not all of its elements are necessarily +defined. That is to say, one can create an empty `PartialArray` `arr` and e.g. set +`arr[3,2] = 5`, but asking for `arr[1,1]` may throw a `BoundsError` if `[1, 1]` has not been +explicitly set yet. + +`PartialArray`s can be indexed with integer indices and ranges. Indexing is always 1-based. +Other types of indexing allowed by `Base.Array` are not supported. Some of these, namely boolean +indexing, linear indexing into multidimensional arrays, and indexing with arrays, are unsupported +simply because we haven't seen a need for them and haven't bothered to implement them. However, +notably, indexing with colons (i.e. `:`) is not supported for more fundamental reasons. + +To understand this, note that a `PartialArray` has no well-defined size. For example, if one +creates an empty array and sets `arr[3,2]`, it is unclear if that should be taken to mean +that the array has size `(3,2)`: It could be larger, and saying that the size is `(3,2)` +would also misleadingly suggest that all elements within `1:3,1:2` are set. This is also why +colon indexing is ill-defined: If one were to e.g. set `arr[2,:] = [1,2,3]`, we would have no +way of saying whether the right hand side is of an acceptable size or not.
+ +The fact that its size is ill-defined also means that `PartialArray` is not a subtype of +`AbstractArray`. + +All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`, +`push!`, etc. are not defined. The element type of a `PartialArray` will change as needed +under `setindex!!` to accommodate the new values. + +Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type +`ElType` and number of dimensions `num_dims`. Indices into a `PartialArray` must have exactly +`num_dims` elements. + +If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check +if, after the new value has been set, the element type can be made more concrete. If so, +a new `PartialArray` with a more concrete element type is returned. Thus the element type +of any `PartialArray` should always be as concrete as is allowed by the elements in it. + +The internal implementation of a `PartialArray` consists of two arrays: one holding the +data and the other one being a boolean mask indicating which elements are defined. These +internal arrays may need resizing when new elements are set that have index ranges larger +than the current internal arrays. To avoid resizing too often, the internal arrays are +resized in exponentially increasing steps. This means that most `setindex!!` calls are very +fast, but some may incur substantial overhead due to resizing and copying data. It also +means that the largest index set so far determines the memory usage of the `PartialArray`. +`PartialArray`s are thus well-suited when most values in them will eventually be set. If only +a few scattered values are set, a structure like `SparseArray` may be more appropriate.
+""" +struct PartialArray{ElType,num_dims} + data::Array{ElType,num_dims} + mask::Array{Bool,num_dims} + + function PartialArray( + data::Array{ElType,num_dims}, mask::Array{Bool,num_dims} + ) where {ElType,num_dims} + if size(data) != size(mask) + throw(ArgumentError("Data and mask arrays must have the same size")) + end + return new{ElType,num_dims}(data, mask) + end +end + +""" + PartialArray{ElType,num_dims}(args::Vararg{Pair}; min_size=nothing) + +Create a new `PartialArray`. + +The element type and number of dimensions have to be specified explicitly as type +parameters. The positional arguments can be `Pair`s of indices and values. For example, +```jldoctest +julia> using DynamicPPL.VarNamedTuples: PartialArray + +julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) +PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10) +``` + +The optional keyword argument `min_size` can be used to specify the minimum initial size. +This is purely a performance optimisation, to avoid resizing if the eventual size is known +ahead of time. +""" +function PartialArray{ElType,num_dims}( + args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing +) where {ElType,num_dims} + dims = if min_size === nothing + ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + else + map(_partial_array_dim_size, min_size) + end + data = Array{ElType,num_dims}(undef, dims) + mask = fill(false, dims) + pa = PartialArray(data, mask) + + for (inds, value) in args + pa = _setindex!!(pa, convert(ElType, value), inds...) 
+ end + return pa +end + +Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims +Base.eltype(::PartialArray{ElType}) where {ElType} = ElType + +function Base.show(io::IO, pa::PartialArray) + print(io, "PartialArray{", eltype(pa), ",", ndims(pa), "}(") + is_first = true + for inds in CartesianIndices(pa.mask) + if @inbounds(!pa.mask[inds]) + continue + end + if !is_first + print(io, ", ") + else + is_first = false + end + val = @inbounds(pa.data[inds]) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + show(io, Tuple(inds)) + print(io, " => ") + show(io, val) + end + print(io, ")") + return nothing +end + +# We deliberately don't define Base.size for PartialArray, because it is ill-defined. +# The size of the .data field is an implementation detail. +_internal_size(pa::PartialArray, args...) = size(pa.data, args...) + +function Base.copy(pa::PartialArray) + # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively + # copy. + pa_copy = PartialArray(copy(pa.data), copy(pa.mask)) + if VarNamedTuple <: eltype(pa) || eltype(pa) <: VarNamedTuple + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] && pa_copy.data[i] isa VarNamedTuple + pa_copy.data[i] = copy(pa.data[i]) + end + end + end + return pa_copy +end + +function Base.:(==)(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + return false + end + size1 = _internal_size(pa1) + size2 = _internal_size(pa2) + # TODO(mhauru) This could be optimised by not calling checkbounds on all elements + # outside the size of an array, but not sure it's worth it. + merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + for i in CartesianIndices(merge_size) + m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? 
pa1.mask[i] : false + m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false + if m1 != m2 + return false + end + if m1 && (pa1.data[i] != pa2.data[i]) + return false + end + end + return true +end + +function Base.hash(pa::PartialArray, h::UInt) + h = hash(ndims(pa), h) + for i in eachindex(pa.mask) + @inbounds if pa.mask[i] + h = hash(i, h) + h = hash(pa.data[i], h) + end + end + return h +end + +""" + _concretise_eltype!!(pa::PartialArray) + +Concretise the element type of a `PartialArray`. + +Returns a new `PartialArray` with the same data and mask as `pa`, but with its element type +concretised to the most specific type that can hold all currently defined elements. + +Note that this function is fundamentally type unstable if the current element type of `pa` +is not already concrete. + +The name has a `!!` not because it mutates its argument, but because the return value +aliases memory with the argument, and is thus not independent of it. +""" +function _concretise_eltype!!(pa::PartialArray) + if isconcretetype(eltype(pa)) + return pa + end + new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) + # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)? + # In other words, does it help to be more concrete, even if we aren't fully concrete? + if new_et === eltype(pa) + # The types of the elements do not allow for concretisation. + return pa + end + new_data = Array{new_et,ndims(pa)}(undef, _internal_size(pa)) + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] + new_data[i] = pa.data[i] + end + end + return PartialArray(new_data, pa.mask) +end + +"""Return the length needed in a dimension given an index.""" +_length_needed(i::Integer) = i +_length_needed(r::UnitRange) = last(r) + +"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size +we choose it to be. This size will be the smallest possible power of +PARTIAL_ARRAY_DIM_GROWTH_FACTOR. 
Growing PartialArrays in big jumps like this helps reduce
+data copying, as resizes aren't needed as often.
+"""
+function _partial_array_dim_size(min_dim)
+    factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR
+    return factor^(Int(ceil(log(factor, min_dim))))
+end
+
+"""Return the minimum internal size needed for a `PartialArray` to be able to set the value
+at inds.
+"""
+function _min_size(pa::PartialArray, inds)
+    return ntuple(i -> max(_internal_size(pa, i), _length_needed(inds[i])), length(inds))
+end
+
+"""Resize a PartialArray to be able to accommodate the index inds. This operates in place
+for vectors, but makes a copy for higher-dimensional arrays, unless no resizing is
+necessary, in which case this is a no-op."""
+function _resize_partialarray!!(pa::PartialArray, inds)
+    min_size = _min_size(pa, inds)
+    new_size = map(_partial_array_dim_size, min_size)
+    if new_size == _internal_size(pa)
+        return pa
+    end
+    # Generic multidimensional Arrays cannot be resized, so we need to make a new one.
+    # See https://github.com/JuliaLang/julia/issues/37900
+    new_data = Array{eltype(pa),ndims(pa)}(undef, new_size)
+    new_mask = fill(false, new_size)
+    # Note that we have to use CartesianIndices instead of eachindex, because the latter
+    # may use a linear index that does not match between the old and the new arrays.
+    @inbounds for i in CartesianIndices(pa.data)
+        mask_val = pa.mask[i]
+        if mask_val
+            new_mask[i] = mask_val
+            new_data[i] = pa.data[i]
+        end
+    end
+    return PartialArray(new_data, new_mask)
+end
+
+# The below implements the same functionality as above, but more performantly for 1D arrays.
+function _resize_partialarray!!(pa::PartialArray{Eltype,1}, (ind,)) where {Eltype}
+    # Resize arrays to accommodate new indices. 
+ old_size = _internal_size(pa, 1) + min_size = max(old_size, _length_needed(ind)) + new_size = _partial_array_dim_size(min_size) + if new_size == old_size + return pa + end + resize!(pa.data, new_size) + resize!(pa.mask, new_size) + @inbounds pa.mask[(old_size + 1):new_size] .= false + return pa +end + +_getindex(pa::PartialArray, optic::IndexLens) = _getindex(pa, optic.indices...) +_haskey(pa::PartialArray, optic::IndexLens) = _haskey(pa, optic.indices) +function _setindex!!(pa::PartialArray, value, optic::IndexLens) + return _setindex!!(pa, value, optic.indices...) +end + +"""Throw an appropriate error if the given indices are invalid for `pa`.""" +function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + if length(inds) != ndims(pa) + throw(BoundsError(pa, inds)) + end + if _has_colon(inds) + throw(ArgumentError("Indexing PartialArrays with Colon is not supported")) + end + return nothing +end + +function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if !_haskey(pa, inds) + throw(BoundsError(pa, inds)) + end + return getindex(pa.data, inds...) +end + +function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} + _check_index_validity(pa, inds) + return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) +end + +function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + pa = if checkbounds(Bool, pa.mask, inds...) + pa + else + _resize_partialarray!!(pa, inds) + end + new_data = setindex!!(pa.data, value, inds...) + if _is_multiindex(inds) + pa.mask[inds...] .= true + else + pa.mask[inds...] 
= true + end + return _concretise_eltype!!(PartialArray(new_data, pa.mask)) +end + +Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2) + +function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex) + m1 = x1.mask[ind] + m2 = x2.mask[ind] + return if m1 && m2 + _merge_recursive(x1.data[ind], x2.data[ind]) + elseif m2 + x2.data[ind] + else + x1.data[ind] + end +end + +# TODO(mhauru) Would this benefit from a specialised method for 1D PartialArrays? +function _merge_recursive(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + throw( + ArgumentError("Cannot merge PartialArrays with different numbers of dimensions") + ) + end + num_dims = ndims(pa1) + merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) + return if merge_size == _internal_size(pa2) + # Either pa2 is strictly bigger than pa1 or they are equal in size. + result = copy(pa2) + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + else + if merge_size == _internal_size(pa1) + # pa1 is bigger than pa2 + result = copy(pa1) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result = setindex!!( + result, _merge_element_recursive(result, pa2, i), Tuple(i)... + ) + end + end + result + else + # Neither is strictly bigger than the other. + et = promote_type(eltype(pa1), eltype(pa2)) + new_data = Array{et,num_dims}(undef, merge_size) + new_mask = fill(false, merge_size) + result = PartialArray(new_data, new_mask) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result.mask[i] = true + result.data[i] = pa2.data[i] + end + end + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... 
+                    )
+                end
+            end
+            result
+        end
+    end
+end
+
+function Base.keys(pa::PartialArray)
+    inds = findall(pa.mask)
+    lenses = map(x -> IndexLens(Tuple(x)), inds)
+    ks = Any[]
+    for lens in lenses
+        val = getindex(pa.data, lens.indices...)
+        if val isa VarNamedTuple
+            subkeys = keys(val)
+            for vn in subkeys
+                sublens = _varname_to_lens(vn)
+                push!(ks, _compose_no_identity(sublens, lens))
+            end
+        else
+            push!(ks, lens)
+        end
+    end
+    return ks
+end
+
+"""
+    VarNamedTuple{Names,Values}
+
+A `NamedTuple`-like structure with `VarName` keys.
+
+`VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an
+efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and
+`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Another notable method
+is `merge`, which recursively merges two `VarNamedTuple`s.
+
+There are two major limitations on indexing `VarNamedTuple`s:
+
+* `VarName`s with `Colon`s (e.g. `a[:]`) are not supported. This is because the meaning of
+  `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined.
+* Any `VarName`s with `IndexLens`es must have a consistent number of indices. That is, one
+  cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`.
+
+`setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store
+heterogeneous data under different indices of the same symbol. That is, if either
+
+* one sets `a[1]` and `a[2]` to be of different types, or
+* if `a[1]` and `a[2]` both exist, one sets `a[1].b` without setting `a[2].b`,
+
+then getting values for `a[1]` or `a[2]` will not be type stable.
+
+`VarNamedTuple` is intrinsically linked to `PartialArray`, which it uses to store data
+related to `VarName`s with `IndexLens` components.
+"""
+struct VarNamedTuple{Names,Values}
+    data::NamedTuple{Names,Values}
+end
+
+VarNamedTuple(; kwargs...) 
= VarNamedTuple((; kwargs...))
+
+Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data
+Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h)
+
+function Base.show(io::IO, vnt::VarNamedTuple)
+    if isempty(vnt.data)
+        return print(io, "VarNamedTuple()")
+    end
+    print(io, "VarNamedTuple")
+    show(io, vnt.data)
+    return nothing
+end
+
+function Base.copy(vnt::VarNamedTuple{names}) where {names}
+    # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements,
+    # which we recursively copy.
+    return VarNamedTuple(
+        NamedTuple{names}(
+            map(
+                x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data)
+            ),
+        ),
+    )
+end
+
+"""
+    _varname_to_lens(name::VarName{S}) where {S}
+
+Convert a `VarName` to an `Accessor` lens, wrapping the first symbol in a `PropertyLens`.
+
+This is used to simplify method dispatch for `_getindex`, `_setindex!!`, and `_haskey`, by
+considering `VarName`s to just be a special case of lenses.
+"""
+function _varname_to_lens(name::VarName{S}) where {S}
+    return _compose_no_identity(getoptic(name), PropertyLens{S}())
+end
+
+_getindex(vnt::VarNamedTuple, name::VarName) = _getindex(vnt, _varname_to_lens(name))
+_getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = getindex(vnt.data, S)
+_getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name]
+
+_haskey(vnt::VarNamedTuple, name::VarName) = _haskey(vnt, _varname_to_lens(name))
+_haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S)
+_haskey(vnt::VarNamedTuple, ::typeof(identity)) = true
+_haskey(::VarNamedTuple, ::IndexLens) = false
+
+function _setindex!!(vnt::VarNamedTuple, value, name::VarName)
+    return _setindex!!(vnt, value, _varname_to_lens(name))
+end
+
+function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S}
+    # I would like for this to just read
+    # return VarNamedTuple(_setindex!!(vnt.data, value, S))
+    # but that seems to be type unstable. Why? 
Shouldn't it obviously be the same as the
+    # below?
+    return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))))
+end
+
+Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2)
+
+# This needs to be a generated function for type stability.
+@generated function _merge_recursive(
+    vnt1::VarNamedTuple{names1}, vnt2::VarNamedTuple{names2}
+) where {names1,names2}
+    all_names = union(names1, names2)
+    exs = Expr[]
+    push!(exs, :(data = (;)))
+    for name in all_names
+        val_expr = if name in names1 && name in names2
+            :(_merge_recursive(vnt1.data.$name, vnt2.data.$name))
+        elseif name in names1
+            :(vnt1.data.$name)
+        else
+            :(vnt2.data.$name)
+        end
+        push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,)))))
+    end
+    push!(exs, :(return VarNamedTuple(data)))
+    return Expr(:block, exs...)
+end
+
+# TODO(mhauru) The below remains unfinished and undertested. I think it's incorrect for more
+# complex VarNames. It is unexported though.
+"""
+    apply!!(func, vnt::VarNamedTuple, name::VarName)
+
+Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`.
+
+```jldoctest
+julia> using DynamicPPL: VarNamedTuple, setindex!!, @varname
+
+julia> using DynamicPPL.VarNamedTuples: apply!!
+
+julia> vnt = VarNamedTuple()
+VarNamedTuple()
+
+julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a))
+VarNamedTuple(a = [1, 2, 3],)
+
+julia> apply!!(x -> x .+ 1, vnt, @varname(a))
+VarNamedTuple(a = [2, 3, 4],)
+```
+"""
+function apply!!(func, vnt::VarNamedTuple, name::VarName)
+    if !haskey(vnt, name)
+        throw(KeyError(repr(name)))
+    end
+    subdata = _getindex(vnt, name)
+    new_subdata = func(subdata)
+    return _setindex!!(vnt, new_subdata, name)
+end
+
+# TODO(mhauru) Should this return tuples, like it does now? That makes sense for
+# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big.
+# Also, this is not very type stable, it fails even in basic cases. 
A generated function +# would help, but I failed to make one. Might be something to do with a recursive +# generated function. +function Base.keys(vnt::VarNamedTuple) + result = () + for sym in keys(vnt.data) + subdata = vnt.data[sym] + if subdata isa VarNamedTuple + subkeys = keys(subdata) + result = ( + result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... + ) + elseif subdata isa PartialArray + subkeys = keys(subdata) + result = (result..., (VarName{sym}(lens) for lens in subkeys)...) + else + result = (result..., VarName{sym}()) + end + end + return result +end + +# The following methods, indexing with ComposedFunction, are exactly the same for +# VarNamedTuple and PartialArray, since they just fall back on indexing with the outer and +# inner lenses. +const VNT_OR_PA = Union{VarNamedTuple,PartialArray} + +function _getindex(x::VNT_OR_PA, optic::ComposedFunction) + subdata = _getindex(x, optic.inner) + return _getindex(subdata, optic.outer) +end + +function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction) + sub = if _haskey(vnt, optic.inner) + _setindex!!(_getindex(vnt, optic.inner), value, optic.outer) + else + make_leaf(value, optic.outer) + end + return _setindex!!(vnt, sub, optic.inner) +end + +function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) + return _haskey(vnt, optic.inner) && _haskey(_getindex(vnt, optic.inner), optic.outer) +end + +# The entry points for getting, setting, and checking, using the familiar functions. +Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) +Base.getindex(vnt::VarNamedTuple, vn::VarName) = _getindex(vnt, vn) +BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) = _setindex!!(vnt, value, vn) + +Base.haskey(vnt::PartialArray, key) = _haskey(vnt, key) +Base.getindex(vnt::PartialArray, inds...) = _getindex(vnt, inds...) +BangBang.setindex!!(vnt::PartialArray, value, inds...) = _setindex!!(vnt, value, inds...) 
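+
+# Example usage of the entry points above (illustrative sketch, not a doctest; assumes
+# `@varname` and BangBang's `setindex!!` are in scope):
+#
+#   vnt = VarNamedTuple()
+#   vnt = setindex!!(vnt, [1.0, 2.0], @varname(x.y[1:2]))  # stores a PartialArray under x.y
+#   getindex(vnt, @varname(x.y[2]))                        # recurses through the nested lenses
+#   haskey(vnt, @varname(x.y[3]))                          # false: index 3 was never set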
+
+"""
+    make_leaf(value, optic)
+
+Make a new leaf node for a VarNamedTuple.
+
+This is the function that causes any `optic` that is a `PropertyLens` to be stored as a
+`VarNamedTuple`, any `IndexLens` to be stored as a `PartialArray`, and the `identity`
+optic to be stored as a raw value. It is the link that joins `VarNamedTuple` and
+`PartialArray` together.
+"""
+make_leaf(value, ::typeof(identity)) = value
+make_leaf(value, ::PropertyLens{S}) where {S} = VarNamedTuple(NamedTuple{(S,)}((value,)))
+
+function make_leaf(value, optic::ComposedFunction)
+    sub = make_leaf(value, optic.outer)
+    return make_leaf(sub, optic.inner)
+end
+
+function make_leaf(value, optic::IndexLens)
+    inds = optic.indices
+    num_inds = length(inds)
+    # Check if any of the indices are ranges or colons. If yes, value needs to be an
+    # AbstractArray. Otherwise it needs to be an individual value.
+    et = _is_multiindex(inds) ? eltype(value) : typeof(value)
+    pa = PartialArray{et,num_inds}()
+    return _setindex!!(pa, value, optic)
+end
+
+end
diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl
index 17b851d1d..e5d2f2c2e 100644
--- a/src/varnamedvector.jl
+++ b/src/varnamedvector.jl
@@ -1355,22 +1355,6 @@ function nextrange(vnv::VarNamedVector, x)
     return (offset + 1):(offset + length(x))
 end
 
-# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
-# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only
-# the latter one would be kept.
-"""
-    _compose_no_identity(f, g)
-
-Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.
-
-This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type
-conflicts. 
-""" -_compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity - """ shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 383d7593d..8a0fb3954 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -30,17 +30,13 @@ using Mooncake: Mooncake else unlinked_vi end - nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + ranges = DynamicPPL.get_ranges_and_linked(vi) params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using the ranges # directly - range_with_linked = if AbstractPPL.getoptic(vn) === identity - nt_ranges[AbstractPPL.getsym(vn)] - else - dict_ranges[vn] - end + range_with_linked = ranges[vn] @test params[range_with_linked.range] == DynamicPPL.getindex_internal(vi, vn) # Check that the link status is correct diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl new file mode 100644 index 000000000..67f3d5c2b --- /dev/null +++ b/test/varnamedtuple.jl @@ -0,0 +1,460 @@ +module VarNamedTupleTests + +using Test: @inferred, @test, @test_throws, @testset +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple +using DynamicPPL.VarNamedTuples: PartialArray +using AbstractPPL: VarName, prefix +using BangBang: setindex!! + +""" + test_invariants(vnt::VarNamedTuple) + +Test properties that should hold for all VarNamedTuples. + +Uses @test for all the tests. Intended to be called inside a @testset. +""" +function test_invariants(vnt::VarNamedTuple) + # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. 
+ for k in keys(vnt) + @test haskey(vnt, k) + v = getindex(vnt, k) + vnt2 = setindex!!(copy(vnt), v, k) + @test vnt == vnt2 + @test hash(vnt) == hash(vnt2) + end + # Check that the printed representation can be parsed back to an equal VarNamedTuple. + vnt3 = eval(Meta.parse(repr(vnt))) + @test vnt == vnt3 + @test hash(vnt) == hash(vnt3) + # Check that merge with an empty VarNamedTuple is a no-op. + @test merge(vnt, VarNamedTuple()) == vnt + @test merge(VarNamedTuple(), vnt) == vnt +end + +@testset "VarNamedTuple" begin + @testset "Construction" begin + vnt1 = VarNamedTuple() + test_invariants(vnt1) + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) + vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + test_invariants(vnt1) + vnt2 = VarNamedTuple(; + a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) + ) + test_invariants(vnt2) + @test vnt1 == vnt2 + + pa1 = PartialArray{Float64,1}() + pa1 = setindex!!(pa1, 1.0, 16) + pa2 = PartialArray{Float64,1}(; min_size=(16,)) + pa2 = setindex!!(pa2, 1.0, 16) + pa3 = PartialArray{Float64,1}(16 => 1.0) + pa4 = PartialArray{Float64,1}((16,) => 1.0) + @test pa1 == pa2 + @test pa1 == pa3 + @test pa1 == pa4 + + pa1 = PartialArray{String,3}() + pa1 = setindex!!(pa1, "a", 2, 3, 4) + pa1 = setindex!!(pa1, "b", 1, 2, 4) + pa2 = PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = setindex!!(pa2, "a", 2, 3, 4) + pa2 = setindex!!(pa2, "b", 1, 2, 4) + pa3 = PartialArray{String,3}((2, 3, 4) => "a", (1, 2, 4) => "b") + @test pa1 == pa2 + @test pa1 == pa3 + + @test_throws BoundsError PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError PartialArray{Int,1}((1,) => "a") + @test_throws MethodError PartialArray{Int,1}((1,) => 1; min_size=(2, 2)) + end + + @testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + @test 
haskey(vnt, @varname(a)) + @test !haskey(vnt, @varname(b)) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + @test haskey(vnt, @varname(b)) + @test haskey(vnt, @varname(b[1])) + @test haskey(vnt, @varname(b[1:3])) + @test !haskey(vnt, @varname(b[4])) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + test_invariants(vnt) + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + @test haskey(vnt, @varname(e.f[3].g.h[2].i)) + @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + test_invariants(vnt) + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + test_invariants(vnt) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + test_invariants(vnt) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + test_invariants(vnt) + + # Not enough, or too many, indices. + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. 
+ @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) + test_invariants(vnt) + + # The below tests are mostly significant for the type stability aspect. For the last + # test to pass, PartialArray needs to actively tighten its eltype when possible. + vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[1].a))) + @test @inferred(getindex(vnt, @varname(n[1].a))) == 1.0 + vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[2].a))) + @test @inferred(getindex(vnt, @varname(n[2].a))) == 1.0 + # This can't be type stable, because n[1] has inhomogeneous types. + vnt = setindex!!(vnt, 1.0, @varname(n[1].b)) + @test getindex(vnt, @varname(n[1].b)) == 1.0 + # The setindex!! call can't be type stable either, but it should return a + # VarNamedTuple with a concrete element type, and hence getindex can be inferred. + vnt = setindex!!(vnt, 1.0, @varname(n[2].b)) + @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0 + test_invariants(vnt) + + # Some funky Symbols in VarNames + # TODO(mhauru) This still isn't as robust as it should be, for instance Symbol(":") + # fails the eval(Meta.parse(print(vnt))) == vnt test because NamedTuple show doesn't + # respect the eval-property. 
+ vn1 = VarName{Symbol("a b c")}() + vnt = @inferred(setindex!!(vnt, 2, vn1)) + @test @inferred(getindex(vnt, vn1)) == 2 + test_invariants(vnt) + vn2 = VarName{Symbol("1")}() + vnt = @inferred(setindex!!(vnt, 3, vn2)) + @test @inferred(getindex(vnt, vn2)) == 3 + test_invariants(vnt) + vn3 = VarName{Symbol("?!")}() + vnt = @inferred(setindex!!(vnt, 4, vn3)) + @test @inferred(getindex(vnt, vn3)) == 4 + test_invariants(vnt) + vnt = VarNamedTuple() + vn4 = prefix(prefix(vn1, vn2), vn3) + vnt = @inferred(setindex!!(vnt, 5, vn4)) + @test @inferred(getindex(vnt, vn4)) == 5 + test_invariants(vnt) + vn5 = prefix(prefix(vn3, vn2), vn1) + vnt = @inferred(setindex!!(vnt, 6, vn5)) + @test @inferred(getindex(vnt, vn5)) == 6 + test_invariants(vnt) + end + + @testset "equality" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + @test vnt1 != vnt2 + + vnt2 = setindex!!(vnt2, 1.0, @varname(a)) + @test vnt1 == vnt2 + + vnt1 = setindex!!(vnt1, [1, 2], @varname(b)) + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, [1, 3], @varname(b)) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, [1, 2], @varname(b)) + + # Try with index lenses too + vnt1 = setindex!!(vnt1, 2, @varname(c[2])) + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, 3, @varname(c[2])) + @test vnt1 != vnt2 + vnt2 = setindex!!(vnt2, 2, @varname(c[2])) + + vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2])) + vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2])) + @test vnt1 == vnt2 + + vnt2 = setindex!!(vnt2, :b, @varname(d.e[2])) + @test vnt1 != vnt2 + end + + @testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = 
setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + expected_merge = 
setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1]))
+        expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1]))
+        @test merge(vnt1, vnt2) == expected_merge
+
+        # PartialArrays with different sizes.
+        vnt1 = VarNamedTuple()
+        vnt2 = VarNamedTuple()
+        vnt1 = setindex!!(vnt1, 1, @varname(a[1]))
+        vnt1 = setindex!!(vnt1, 1, @varname(a[257]))
+        vnt2 = setindex!!(vnt2, 2, @varname(a[1]))
+        vnt2 = setindex!!(vnt2, 2, @varname(a[2]))
+        expected_merge_12 = VarNamedTuple()
+        expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257]))
+        expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1]))
+        expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2]))
+        @test @inferred(merge(vnt1, vnt2)) == expected_merge_12
+        expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1]))
+        @test @inferred(merge(vnt2, vnt1)) == expected_merge_21
+
+        vnt1 = VarNamedTuple()
+        vnt2 = VarNamedTuple()
+        vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1]))
+        vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1]))
+        vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1]))
+        vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257]))
+        expected_merge_12 = VarNamedTuple()
+        expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1]))
+        expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1]))
+        expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257]))
+        @test merge(vnt1, vnt2) == expected_merge_12
+        expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1]))
+        @test merge(vnt2, vnt1) == expected_merge_21
+    end
+
+    @testset "keys" begin
+        vnt = VarNamedTuple()
+        @test @inferred(keys(vnt)) == ()
+
+        vnt = setindex!!(vnt, 1.0, @varname(a))
+        # TODO(mhauru) Note that the below passes @inferred, but the later ones don't.
+        # We should improve type stability of keys(). 
+ @test @inferred(keys(vnt)) == (@varname(a),) + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test keys(vnt) == (@varname(a), @varname(b)) + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test keys(vnt) == (@varname(a), @varname(b)) + + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + + vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + ) + + vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + ) + + vnt = setindex!!(vnt, 1.0, @varname(j[6])) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + ) + + vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) + @test keys(vnt) == ( + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + ) + end + + @testset "printing" begin + vnt = VarNamedTuple() + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple()" + + vnt = setindex!!(vnt, "s", @varname(a)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """VarNamedTuple(a = "s",)""" + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """VarNamedTuple(a = "s", b = [1, 2, 3])""" + + vnt 
= setindex!!(vnt, :dada, @varname(c[2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada))""" + + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + # Depending on what's in scope, and maybe sometimes even the Julia version, + # sometimes types in the output are fully qualified, sometimes not. To avoid + # brittle tests, we normalise the output: + output = replace(output, "DynamicPPL." => "", "VarNamedTuples." => "") + @test output == """ + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada), \ + d = VarNamedTuple(\ + e = PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ + VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0),),)),))""" + end +end + +end