From a9609931d4deea5a9622d400b5c368f22b4be465 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Fri, 14 Nov 2025 00:41:53 +0000
Subject: [PATCH 1/4] Make FastLDF the default

---
 HISTORY.md                               |  23 +-
 docs/src/api.md                          |   7 +
 ext/DynamicPPLMarginalLogDensitiesExt.jl |  11 +-
 src/DynamicPPL.jl                        |   6 +-
 src/experimental.jl                      |   2 -
 src/fasteval.jl                          |  97 +++---
 src/logdensityfunction.jl                | 377 -----------------------
 test/ad.jl                               | 137 --------
 test/chains.jl                           |   8 +-
 test/fasteval.jl                         |  64 +---
 test/logdensityfunction.jl               |  49 ---
 test/runtests.jl                         |  11 +-
 12 files changed, 113 insertions(+), 679 deletions(-)
 delete mode 100644 src/logdensityfunction.jl
 delete mode 100644 test/ad.jl
 delete mode 100644 test/logdensityfunction.jl

diff --git a/HISTORY.md b/HISTORY.md
index 0f0102ce4..91306c219 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -4,6 +4,17 @@

 ### Breaking changes

+#### Fast Log Density Functions
+
+This version reimplements `LogDensityFunction`, yielding performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
+Exact speedups depend on model size: larger models see smaller speedups, because the bulk of the work is spent in calls to `logpdf`.
+
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
+
+As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
+In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
+If you were previously relying on any of its other fields (such as the stored VarInfo), you will need to store a VarInfo separately.
+
 #### Parent and leaf contexts

 The `DynamicPPL.NodeTrait` function has been removed.
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
 The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
 This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).

-### Other changes
-
-#### FastLDF
-
-Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
-Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
-
-Please note that `FastLDF` is currently considered internal and its API may change without warning.
-We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
-
-For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
-
 ## 0.38.9

 Remove warning when using Enzyme as the AD backend.
diff --git a/docs/src/api.md b/docs/src/api.md
index e81f18dc7..fecb7367e 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
 LogDensityFunction
 ```

+Internally, this is accomplished using:
+
+```@docs
+OnlyAccsVarInfo
+fast_evaluate!!
+``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... ) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6d3900e91..61bf9a485 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,8 +92,10 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, - # LogDensityFunction + # LogDensityFunction and fasteval LogDensityFunction, + fast_evaluate!!, + OnlyAccsVarInfo, # Leaf contexts AbstractContext, contextualize, @@ -198,7 +200,7 @@ include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("logdensityfunction.jl") +include("fasteval.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/experimental.jl b/src/experimental.jl index c644c09b2..8c82dca68 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,8 +2,6 @@ module Experimental using DynamicPPL: DynamicPPL -include("fasteval.jl") - # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) 
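For reference, the user-facing workflow is unchanged by making the fast implementation the default. A minimal sketch, adapted from the `jldoctest` in `src/logdensityfunction.jl` (deleted later in this patch); the `demo` model and the expected values come from that docstring:

```julia
using Distributions, DynamicPPL, LogDensityProblems
import ADTypes, ForwardDiff

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end
model = demo(1.0)

# Without an adtype, only zeroth-order queries are available.
f = LogDensityFunction(model)
LogDensityProblems.logdensity(f, [0.0])  # -2.3378770664093453
LogDensityProblems.dimension(f)          # 1

# With an adtype, gradients become available as well.
fg = LogDensityFunction(model; adtype=ADTypes.AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(fg, [0.0])  # (-2.3378770664093453, [1.0])
```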
diff --git a/src/fasteval.jl b/src/fasteval.jl index 722760fa1..0becebb8e 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -30,11 +30,11 @@ import DifferentiationInterface as DI using Random: Random """ - DynamicPPL.Experimental.fast_evaluate!!( + DynamicPPL.fast_evaluate!!( [rng::Random.AbstractRNG,] model::Model, strategy::AbstractInitStrategy, - accs::AccumulatorTuple, params::AbstractVector{<:Real} + accs::AccumulatorTuple, ) Evaluate a model using parameters obtained via `strategy`, and only computing the results in @@ -84,7 +84,7 @@ end end """ - FastLDF( + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=VarInfo(model); @@ -115,10 +115,10 @@ There are several options for `getlogdensity` that are 'supported' out of the bo since transforms are only applied to random variables) !!! note - By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of - `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created - with a linked or unlinked VarInfo. This is done primarily to ease interoperability with - MCMC samplers. + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. If you provide one of these functions, a `VarInfo` will be automatically created for you. If you provide a different function, you have to manually create a VarInfo and pass it as the @@ -126,15 +126,16 @@ third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the gradient of the log density. -Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend -itself to have been loaded (e.g. with `import Backend`). +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). ## Fields -Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: -- `fastldf.model`: The original model from which this `FastLDF` was constructed. -- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD type was provided. # Extended help @@ -172,8 +173,9 @@ Traditionally, this problem has been solved by `unflatten`, because that functio place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. However, we want to avoid using a metadata. Thus, here, we _extract this information from -the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we -store a mapping from VarNames to ranges in that vector, along with link status. +the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the +LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with +link status. For VarNames with identity optics, this is stored in a NamedTuple for efficiency. 
For all other VarNames, this is stored in a Dict. The internal data structure used to represent this @@ -185,13 +187,13 @@ ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quick parameter values from the vector. Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. +the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle +models which have variable numbers of parameters, or models which may visit random variables +in different orders depending on stochastic control flow. **Indeed, silent errors may occur +with such models.** This is a general limitation of vectorised parameters: the original +`unflatten` + `evaluate!!` approach also fails with such models. """ -struct FastLDF{ +struct LogDensityFunction{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -206,7 +208,7 @@ struct FastLDF{ _adprep::ADP _dim::Int - function FastLDF( + function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=VarInfo(model); @@ -224,7 +226,7 @@ struct FastLDF{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) @@ -261,56 +263,73 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt)(params::AbstractVector{<:Real}) strategy = InitFromParams( VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = fast_ldf_accs(f.getlogdensity) - _, vi = fast_evaluate!!(f.model, strategy, accs) + _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs) return f.getlogdensity(vi) end -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params ) end function LogDensityProblems.logdensity_and_gradient( - fldf::FastLDF, params::AbstractVector{<:Real} + ldf::LogDensityFunction, params::AbstractVector{<:Real} ) return DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), - fldf._adprep, - fldf.adtype, + ldf._adprep, + ldf.adtype, params, ) end -function 
LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} -) where {M} +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} ) where {M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.dimension(fldf::FastLDF) - return fldf._dim +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end +""" + tweak_adtype( + adtype::ADTypes.AbstractADType, + model::Model, + varinfo::AbstractVarInfo, + ) + +Return an 'optimised' form of the adtype. This is useful for doing +backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating +the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). +The model is passed as a parameter in case the optimisation depends on the +model. + +By default, this just returns the input unchanged. +""" +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype + ###################################################### # Helper functions to extract ranges and link status # ###################################################### diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl deleted file mode 100644 index 7c7438c9f..000000000 --- a/src/logdensityfunction.jl +++ /dev/null @@ -1,377 +0,0 @@ -using AbstractMCMC: AbstractModel -import DifferentiationInterface as DI - -""" - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. - -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. -""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. 
-- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) - -!!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` -""" -struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" - model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" - adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} - - function LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - if adtype === nothing - prep = nothing - else - # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end - end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep - ) - end -end - -""" - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) - -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. -""" -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end -end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. - """ - return error(msg) -end - -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) -end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, - ) - -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. 
-""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo -) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) -end - -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V - ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. -""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) -end - -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) -end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils - -""" - tweak_adtype( - adtype::ADTypes.AbstractADType, - model::Model, - varinfo::AbstractVarInfo, - ) - -Return an 'optimised' form of the adtype. This is useful for doing -backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating -the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). -The model is passed as a parameter in case the optimisation depends on the -model. - -By default, this just returns the input unchanged. -""" -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype - -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. 
Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. -""" -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false - -""" - getmodel(f) - -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model - -""" - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) -end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. -""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index 43b877d62..12a9ece71 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -66,7 +66,7 @@ using Test end end -@testset "ParamsWithStats from FastLDF" begin +@testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) @@ -77,9 +77,9 @@ end end params = [x for x in vi[:]] - # Get 
the ParamsWithStats using FastLDF - fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) - ps = ParamsWithStats(params, fldf) + # Get the ParamsWithStats using LogDensityFunction + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) + ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected @test length(ps.params) == length(keys(vi)) diff --git a/test/fasteval.jl b/test/fasteval.jl index db2333711..d740d4b6c 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,4 +1,4 @@ -module DynamicPPLFastLDFTests +module DynamicPPLFastEvalTests using AbstractPPL: AbstractPPL using Chairmarks @@ -6,7 +6,6 @@ using DynamicPPL using Distributions using DistributionsAD: filldist using ADTypes -using DynamicPPL.Experimental: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest using LinearAlgebra: I using Test @@ -14,14 +13,9 @@ using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff -# Need to include this block here in case we run this test file standalone -@static if VERSION < v"1.12" - using Pkg - Pkg.add("Mooncake") - using Mooncake: Mooncake -end +using Mooncake: Mooncake -@testset "FastLDF: Correctness" begin +@testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, @@ -36,7 +30,7 @@ end else unlinked_vi end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) @@ -52,26 +46,6 @@ end # Check that the link status is correct @test range_with_linked.is_linked == islinked end - - # Compare results of FastLDF vs ordinary LogDensityFunction. These tests - # can eventually go once we replace LogDensityFunction with FastLDF, but - # for now it helps to have this check! (Eventually we should just check - # against manually computed log-densities). - # - # TODO(penelopeysm): I think we need to add tests for some really - # pathological models here. - @testset "$getlogdensity" for getlogdensity in ( - DynamicPPL.getlogjoint_internal, - DynamicPPL.getlogjoint, - DynamicPPL.getloglikelihood, - DynamicPPL.getlogprior_internal, - DynamicPPL.getlogprior, - ) - ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) - fldf = FastLDF(m, getlogdensity, vi) - @test LogDensityProblems.logdensity(ldf, params) ≈ - LogDensityProblems.logdensity(fldf, params) - end end end end @@ -86,7 +60,7 @@ end end N = 100 model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) + ldf = DynamicPPL.LogDensityFunction(model) xs = [1.0] @test LogDensityProblems.logdensity(ldf, xs) ≈ @@ -95,7 +69,7 @@ end end end -@testset "FastLDF: performance" begin +@testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when # not using TSVI). 
@@ -119,35 +93,29 @@ end @testset for model in (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) + bench = median(@be LogDensityProblems.logdensity(ldf, x)) @test iszero(bench.allocs) end end end -@testset "AD with FastLDF" begin +@testset "AD with LogDensityFunction" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() - test_adtypes = @static if VERSION < v"1.12" - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = [p for p in linked_varinfo[:]] # Calculate reference logp + gradient of logp using ForwardDiff @@ -173,7 +141,7 @@ end test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) + ldf = LogDensityFunction(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl deleted file mode 100644 index fbd868f71..000000000 --- a/test/logdensityfunction.jl +++ /dev/null @@ -1,49 +0,0 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model - end -end - -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) - end - end - - @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ldf = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.capabilities(typeof(ldf)) == - LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == - LogDensityProblems.LogDensityOrder{1}() - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 1474b426a..84aa00982 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using ForwardDiff using LogDensityProblems using MacroTools using MCMCChains +using Mooncake using StableRNGs using ReverseDiff using Mooncake @@ -57,7 +58,9 @@ include("test_util.jl") include("simple_varinfo.jl") include("model.jl") include("distribution_wrappers.jl") - include("logdensityfunction.jl") + end + + if GROUP == "All" || GROUP == "Group2" include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -69,9 +72,7 @@ include("test_util.jl") include("submodels.jl") include("chains.jl") include("bijector.jl") - end - - if GROUP == "All" || GROUP == "Group2" + include("fasteval.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") @@ -80,8 +81,6 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") - include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." 
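The breaking change above (a `LogDensityFunction` no longer carries a VarInfo) migrates mechanically: keep the VarInfo next to the `LogDensityFunction`, exactly as the `LogDensityFunctionWrapper` change in the MarginalLogDensities extension does. A minimal sketch — the wrapper type `LDFWithVarInfo` is illustrative, not part of DynamicPPL's API:

```julia
using Distributions, DynamicPPL

@model demo() = x ~ Normal()
model = demo()
vi = VarInfo(model)

# `ldf.varinfo` is no longer a supported field, so store the VarInfo alongside.
struct LDFWithVarInfo{L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo}
    ldf::L
    varinfo::V
end

lw = LDFWithVarInfo(DynamicPPL.LogDensityFunction(model, getlogjoint, vi), vi)
lw.ldf.model  # still supported: `model` and `adtype` remain public
lw.varinfo    # the separately stored VarInfo
```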
From e1652492c4e5833f236b8d35184cabb7acb6ad3a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 00:46:04 +0000 Subject: [PATCH 2/4] Add miscellaneous LogDensityProblems tests --- test/fasteval.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/fasteval.jl b/test/fasteval.jl index d740d4b6c..7a143a140 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -69,6 +69,45 @@ using Mooncake: Mooncake end end +@testset "LogDensityFunction: interface" begin + # miscellaneous parts of the LogDensityProblems interface + @testset "dimensions" begin + @model function m1() + x ~ Normal() + y ~ Normal() + return nothing + end + model = m1() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 2 + + @model function m2() + x ~ Dirichlet(ones(4)) + y ~ Categorical(x) + return nothing + end + model = m2() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 5 + linked_vi = DynamicPPL.link!!(VarInfo(model), model) + ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi) + @test LogDensityProblems.dimension(ldf) == 4 + end + + @testset "capabilities" begin + @model f() = x ~ Normal() + model = f() + # No adtype + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.capabilities(typeof(ldf)) == + LogDensityProblems.LogDensityOrder{0}() + # With adtype + ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf)) == + LogDensityProblems.LogDensityOrder{1}() + end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when From d1c002f4784648868369678c81ce2b5caa1aad4a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 02:22:31 +0000 Subject: [PATCH 3/4] Use `init!!` instead of `fast_evaluate!!` --- docs/src/api.md | 5 ++--- src/DynamicPPL.jl | 4 +++- src/chains.jl | 8 +++---- src/fasteval.jl | 56 +---------------------------------------------- src/model.jl | 54 +++++++++++++++++++++++++++++++++------------ 5 files changed, 49 insertions(+), 78 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index fecb7367e..adb476db5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,11 +66,10 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` -Internally, this is accomplished using: +Internally, this is accomplished using [`init!!`](@ref) on: ```@docs OnlyAccsVarInfo -fast_evaluate!! ``` ## Condition and decondition @@ -517,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo. It is really a thin wrapper around using `evaluate!!` with an `InitContext`. ```@docs -DynamicPPL.init!! +init!! ``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. 
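A minimal sketch of the `init!!`-plus-`OnlyAccsVarInfo` pattern that the updated api.md stanza refers to, mirroring the `LogDensityAt` body changed later in this patch (`demo` is an illustrative model; `DynamicPPL.InitFromPrior` is the default strategy named in the `init!!` docstring):

```julia
using Distributions, DynamicPPL

@model demo() = x ~ Normal()
model = demo()

# Draw `x` from the prior and evaluate, accumulating only the log prior;
# `OnlyAccsVarInfo` carries no metadata, just the accumulators.
accs = DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator(),))
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), DynamicPPL.InitFromPrior())
getlogprior(vi)  # log density of the sampled `x` under Normal()
```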
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 61bf9a485..375cf731e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,9 +92,11 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, + # evaluation + evaluate!!, + init!!, # LogDensityFunction and fasteval LogDensityFunction, - fast_evaluate!!, OnlyAccsVarInfo, # Leaf contexts AbstractContext, diff --git a/src/chains.jl b/src/chains.jl index 892423822..f176b8e68 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -137,7 +137,7 @@ end """ ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.Experimental.FastLDF, + ldf::DynamicPPL.LogDensityFunction, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, @@ -174,9 +174,7 @@ function ParamsWithStats( else (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) end - _, vi = DynamicPPL.Experimental.fast_evaluate!!( - ldf.model, strategy, AccumulatorTuple(accs) - ) + _, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy) params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values if include_log_probs stats = merge( diff --git a/src/fasteval.jl b/src/fasteval.jl index 0becebb8e..65eab448e 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -29,60 +29,6 @@ using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI using Random: Random -""" - DynamicPPL.fast_evaluate!!( - [rng::Random.AbstractRNG,] - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, - ) - -Evaluate a model using parameters obtained via `strategy`, and only computing the results in -the provided accumulators. - -It is assumed that the accumulators passed in have been initialised to appropriate values, -as this function will not reset them. The default constructors for each accumulator will do -this for you correctly. - -Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs` -argument may be mutated (depending on how the accumulators are implemented); hence the `!!` -in the function name. -""" -@inline function fast_evaluate!!( - # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads - # to extra allocations (even for trivial models) and much slower runtime. - rng::Random.AbstractRNG, - model::Model, - strategy::AbstractInitStrategy, - accs::AccumulatorTuple, -) - ctx = InitContext(rng, strategy) - model = DynamicPPL.setleafcontext(model, ctx) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        param_eltype = DynamicPPL.get_param_eltype(strategy)
-        accs = map(accs) do acc
-            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
-        end
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    return DynamicPPL._evaluate!!(model, vi)
-end
-@inline function fast_evaluate!!(
-    model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
-)
-    # This `@inline` is also mandatory for performance
-    return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
-end
-
 """
     DynamicPPL.LogDensityFunction(
         model::Model,
@@ -274,7 +220,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real})
         VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
     )
     accs = fast_ldf_accs(f.getlogdensity)
-    _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
+    _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
     return f.getlogdensity(vi)
 end

diff --git a/src/model.jl b/src/model.jl
index 2bcfe8f98..9029318b1 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -881,30 +881,59 @@ end
     [init_strategy::AbstractInitStrategy=InitFromPrior()]
 )

-Evaluate the `model` and replace the values of the model's random variables
-in the given `varinfo` with new values, using a specified initialisation strategy.
-If the values in `varinfo` are not set, they will be added
-using a specified initialisation strategy.
+Evaluate the `model` and replace the values of the model's random variables in the given
+`varinfo` with new values, using the specified initialisation strategy. Values that are
+not yet set in `varinfo` are added using the same strategy.

 If `init_strategy` is not provided, defaults to `InitFromPrior()`.

 Returns a tuple of the model's return value, plus the updated `varinfo` object.
 """
-function init!!(
+@inline function init!!(
+    # Note that this `@inline` is mandatory for performance, especially for
+    # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for
+    # trivial models) and much slower runtime.
     rng::Random.AbstractRNG,
     model::Model,
-    varinfo::AbstractVarInfo,
-    init_strategy::AbstractInitStrategy=InitFromPrior(),
+    vi::AbstractVarInfo,
+    strategy::AbstractInitStrategy=InitFromPrior(),
 )
-    new_model = setleafcontext(model, InitContext(rng, init_strategy))
-    return evaluate!!(new_model, varinfo)
+    ctx = InitContext(rng, strategy)
+    model = DynamicPPL.setleafcontext(model, ctx)
+    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
+    # it _should_ do, but this is wrong regardless.
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
+    return if Threads.nthreads() > 1
+        # TODO(penelopeysm): The logic for setting the eltype of the accs is very similar
+        # to that used in `unflatten`. We need it here because the VarInfo `vi` won't have
+        # been filled with parameters before `init!!` is called.
+        #
+        # Note that this eltype promotion is only needed for threadsafe evaluation. In an
+        # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a
+        # similar method. In other words, it should not be here, and it should not be inside
+        # `unflatten` either. The problem is performance. Shifting this code around can have
+        # massive, inexplicable impacts on performance. This should be investigated
+        # properly.
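+        # (For example, with a ForwardDiff-style adtype the parameters are dual
+        # numbers, so plain Float64 accumulators must first be widened to the
+        # duals' type before evaluation; that is what the conversion below does.)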
+ param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(vi.accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + vi = DynamicPPL.setaccs!!(vi, accs) + tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + else + return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + end end -function init!!( - model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), +@inline function init!!( + model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - return init!!(Random.default_rng(), model, varinfo, init_strategy) + # This `@inline` is also mandatory for performance + return init!!(Random.default_rng(), model, vi, strategy) end """ From 0c10eaa8f09d75c5f4b85e0bcd545ff9a1bf49a3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 25 Nov 2025 11:11:35 +0000 Subject: [PATCH 4/4] Rename files, rebalance tests --- src/DynamicPPL.jl | 4 ++-- src/{fasteval.jl => logdensityfunction.jl} | 0 test/{fasteval.jl => logdensityfunction.jl} | 2 +- test/runtests.jl | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) rename src/{fasteval.jl => logdensityfunction.jl} (100%) rename test/{fasteval.jl => logdensityfunction.jl} (99%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 375cf731e..a885f6a96 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -95,7 +95,7 @@ export AbstractVarInfo, # evaluation evaluate!!, init!!, - # LogDensityFunction and fasteval + # LogDensityFunction LogDensityFunction, OnlyAccsVarInfo, # Leaf contexts @@ -202,7 +202,7 @@ include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("fasteval.jl") +include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/fasteval.jl b/src/logdensityfunction.jl similarity index 100% rename from src/fasteval.jl rename to src/logdensityfunction.jl diff --git a/test/fasteval.jl b/test/logdensityfunction.jl similarity index 99% rename from test/fasteval.jl rename to test/logdensityfunction.jl index 7a143a140..06492d6e1 100644 --- a/test/fasteval.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,4 @@ -module DynamicPPLFastEvalTests +module DynamicPPLLDFTests using AbstractPPL: AbstractPPL using Chairmarks diff --git a/test/runtests.jl b/test/runtests.jl index 84aa00982..9649aebbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,9 +58,6 @@ include("test_util.jl") include("simple_varinfo.jl") include("model.jl") include("distribution_wrappers.jl") - end - - if GROUP == "All" || GROUP == "Group2" include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -71,8 +68,11 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("chains.jl") + end + + if GROUP == "All" || GROUP == "Group2" include("bijector.jl") - include("fasteval.jl") + include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl")