Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
"""
function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.LogDensityFunction,
ldf::DynamicPPL.LogDensityFunction{Tlink},
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)
) where {Tlink}
strategy = InitFromParams(
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
VectorWithRanges{Tlink}(
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
),
nothing,
)
accs = if include_log_probs
Expand Down
20 changes: 15 additions & 5 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ struct RangeAndLinked
end

"""
VectorWithRanges(
VectorWithRanges{Tlink}(
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
vect::AbstractVector{<:Real},
Expand All @@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
"""
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
    # `Tlink` lives only in the type; no field stores it. The `init` method for
    # `InitFromParams{<:VectorWithRanges}` reads it as follows: `true` / `false`
    # means every variable is linked / unlinked, while `nothing` means the link
    # status is mixed and each entry's stored `is_linked` flag is used instead.

    # Ranges for the identity VarNames, kept in a NamedTuple
    iden_varname_ranges::N
    # Ranges for every other VarName, kept in a Dict
    varname_ranges::Dict{VarName,RangeAndLinked}
    # Complete parameter vector; variable values are read by indexing into it
    vect::T

    # `Tlink` cannot be deduced from the arguments, so it must be supplied
    # explicitly as `VectorWithRanges{Tlink}(...)`. `N` and `T` are taken from the
    # argument types.
    function VectorWithRanges{Tlink}(
        iden_ranges::N, other_ranges::Dict{VarName,RangeAndLinked}, param_vect::T
    ) where {Tlink,N,T}
        return new{Tlink,N,T}(iden_ranges, other_ranges, param_vect)
    end
end

function _get_range_and_linked(
Expand All @@ -252,11 +258,15 @@ function init(
::Random.AbstractRNG,
vn::VarName,
dist::Distribution,
p::InitFromParams{<:VectorWithRanges},
)
p::InitFromParams{<:VectorWithRanges{T}},
) where {T}
vr = p.params
range_and_linked = _get_range_and_linked(vr, vn)
transform = if range_and_linked.is_linked
# T can either be `nothing` (i.e., link status is mixed, in which
# case we use the stored link status), or `true` / `false`, which
# indicates that all variables are linked / unlinked.
linked = isnothing(T) ? range_and_linked.is_linked : T
transform = if linked
from_linked_vec_transform(dist)
else
from_vec_transform(dist)
Expand Down
56 changes: 43 additions & 13 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
`unflatten` + `evaluate!!` approach also fails with such models.
"""
struct LogDensityFunction{
# true if all variables are linked; false if all variables are unlinked; nothing if
# mixed
Tlink,
M<:Model,
AD<:Union{ADTypes.AbstractADType,Nothing},
F<:Function,
Expand All @@ -163,6 +166,21 @@ struct LogDensityFunction{
# Figure out which variable corresponds to which index, and
# which variables are linked.
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
# Figure out if all variables are linked, unlinked, or mixed
link_statuses = Bool[]
for ral in all_iden_ranges
push!(link_statuses, ral.is_linked)
end
for (_, ral) in all_ranges
push!(link_statuses, ral.is_linked)
end
Tlink = if all(link_statuses)
true
elseif all(!s for s in link_statuses)
false
else
nothing
end
x = [val for val in varinfo[:]]
dim = length(x)
# Do AD prep if needed
Expand All @@ -172,12 +190,13 @@ struct LogDensityFunction{
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
DI.prepare_gradient(
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
adtype,
x,
)
end
return new{
Tlink,
typeof(model),
typeof(adtype),
typeof(getlogdensity),
Expand Down Expand Up @@ -209,36 +228,45 @@ end
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
    # `Tlink` is carried purely in the type (no field holds it): `true` if all
    # variables are linked, `false` if all are unlinked, `nothing` if mixed.

    # Model that gets evaluated
    model::M
    # Function applied to the evaluated varinfo to extract the log density
    getlogdensity::F
    # Ranges for the identity VarNames
    iden_varname_ranges::N
    # Ranges for all remaining VarNames
    varname_ranges::Dict{VarName,RangeAndLinked}

    # `Tlink` must be given explicitly as `LogDensityAt{Tlink}(...)`; the other
    # type parameters are taken from the argument types.
    function LogDensityAt{Tlink}(
        mdl::M,
        logdensity_fn::F,
        iden_ranges::N,
        other_ranges::Dict{VarName,RangeAndLinked},
    ) where {Tlink,M,F,N}
        return new{Tlink,M,F,N}(mdl, logdensity_fn, iden_ranges, other_ranges)
    end
end
function (f::LogDensityAt)(params::AbstractVector{<:Real})
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
strategy = InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
)
accs = fast_ldf_accs(f.getlogdensity)
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
return f.getlogdensity(vi)
end

function LogDensityProblems.logdensity(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
return LogDensityAt(
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
)(
params
)
end

function LogDensityProblems.logdensity_and_gradient(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return DI.value_and_gradient(
LogDensityAt(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
ldf._adprep,
Expand All @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient(
)
end

function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{T,M,Nothing}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
) where {M}
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.dimension(ldf::LogDensityFunction)
Expand Down
10 changes: 7 additions & 3 deletions test/integration/enzyme/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ using Test: @test, @testset
import Enzyme: set_runtime_activity, Forward, Reverse, Const
using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test

ADTYPES = Dict(
"EnzymeForward" =>
ADTYPES = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Entirely ambivalent about which constructor to use, but curious if you had a reason for changing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, CI crashes with a Julia GC error when using a Dict.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(don't ask 🙃)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...

(
"EnzymeForward",
AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
"EnzymeReverse" =>
),
(
"EnzymeReverse",
AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
),
)

@testset "$ad_key" for (ad_key, ad_type) in ADTYPES
Expand Down
16 changes: 16 additions & 0 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ end
end
end

# Type-stability regression test: for each demo model, `logdensity` must have an
# inferrable return type for both an unlinked and a linked VarInfo.
@testset "LogDensityFunction: Type stability" begin
    @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
        base_vi = DynamicPPL.VarInfo(m)
        @testset "$islinked" for islinked in (false, true)
            # The unlinked case runs first; linking is only applied afterwards.
            vi = islinked ? DynamicPPL.link!!(base_vi, m) : base_vi
            ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
            params = vi[:]
            # `@inferred` errors if the call's return type is not concretely inferred
            @inferred LogDensityProblems.logdensity(ldf, params)
        end
    end
end

@testset "LogDensityFunction: performance" begin
if Threads.nthreads() == 1
# Evaluating these three models should not lead to any allocations (but only when
Expand Down