Skip to content

Commit f69398e

Browse files
committed
OnlyAccsVarInfo -> AccumulatorTuple; fast_evaluate!! -> init!!
1 parent a960993 commit f69398e

File tree

8 files changed

+50
-128
lines changed

8 files changed

+50
-128
lines changed

docs/src/api.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,6 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
6666
LogDensityFunction
6767
```
6868

69-
Internally, this is accomplished using:
70-
71-
```@docs
72-
OnlyAccsVarInfo
73-
fast_evaluate!!
74-
```
75-
7669
## Condition and decondition
7770

7871
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -371,6 +364,7 @@ The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log
371364

372365
```@docs
373366
AbstractAccumulator
367+
AccumulatorTuple
374368
```
375369

376370
DynamicPPL provides the following default accumulators.

src/DynamicPPL.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export AbstractVarInfo,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
5353
LogJacobianAccumulator,
54+
AccumulatorTuple,
5455
push!!,
5556
empty!!,
5657
subset,
@@ -92,10 +93,8 @@ export AbstractVarInfo,
9293
getargnames,
9394
extract_priors,
9495
values_as_in_model,
95-
# LogDensityFunction and fasteval
96+
# LogDensityFunction
9697
LogDensityFunction,
97-
fast_evaluate!!,
98-
OnlyAccsVarInfo,
9998
# Leaf contexts
10099
AbstractContext,
101100
contextualize,
@@ -197,7 +196,7 @@ include("abstract_varinfo.jl")
197196
include("threadsafe.jl")
198197
include("varinfo.jl")
199198
include("simple_varinfo.jl")
200-
include("onlyaccs.jl")
199+
include("accs_as_varinfo.jl")
201200
include("compiler.jl")
202201
include("pointwise_logdensities.jl")
203202
include("fasteval.jl")

src/accs_as_varinfo.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
DynamicPPL.maybe_invlink_before_eval!!(at::AccumulatorTuple, ::Model) = at
2+
DynamicPPL.getaccs(at::AccumulatorTuple) = at
3+
DynamicPPL.setaccs!!(::AccumulatorTuple, new_at::AccumulatorTuple) = new_at
4+
5+
function tilde_assume!!(
6+
ctx::InitContext,
7+
dist::Distribution,
8+
vn::VarName,
9+
vi::Union{AccumulatorTuple,ThreadSafeVarInfo{<:AccumulatorTuple}},
10+
)
11+
# For AccumulatorTuple, since we don't need to write into the metadata part of the
12+
# VarInfo, we can cut out a lot of the code.
13+
val, transform = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
14+
x, inv_logjac = Bijectors.with_logabsdet_jacobian(transform, val)
15+
vi = DynamicPPL.accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
16+
return x, vi
17+
end

src/accumulators.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ used by various AD backends, should implement a method for this function.
132132
convert_eltype(::Type, acc::AbstractAccumulator) = acc
133133

134134
"""
135-
AccumulatorTuple{N,T<:NamedTuple}
135+
AccumulatorTuple{N,T<:NamedTuple} <: AbstractVarInfo
136136
137137
A collection of accumulators, stored as a `NamedTuple` of length `N`
138138
@@ -144,8 +144,18 @@ constraint that the name in the tuple for each accumulator `acc` must be
144144
The constructor can be called with a tuple or a `VarArgs` of `AbstractAccumulators`. The
145145
names will be generated automatically. One can also call the constructor with a `NamedTuple`
146146
but the names in the argument will be discarded in favour of the generated ones.
147+
148+
`AccumulatorTuple` is a subtype of `AbstractVarInfo`, but in reality it only implements the
149+
part of the interface that deals with accumulators. That is, it implements `getaccs`,
150+
`setaccs!!`, and `maybe_invlink_before_eval!!`. These definitions allow several other
151+
'derived' methods to work automatically, such as `getlogjoint`, `accumulate_assume!!`, and
152+
`accumulate_observe!!`. Note that `maybe_invlink_before_eval!!` should not be needed if/when
153+
SimpleVarInfo is removed.
154+
155+
Unfortunately, because the necessary forward definitions are not present at this stage,
156+
these methods have to be defined in a separate file (`src/accs_as_varinfo.jl`).
147157
"""
148-
struct AccumulatorTuple{N,T<:NamedTuple}
158+
struct AccumulatorTuple{N,T<:NamedTuple} <: AbstractVarInfo
149159
nt::T
150160

151161
function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}}

src/chains.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
"""
138138
ParamsWithStats(
139139
param_vector::AbstractVector,
140-
ldf::DynamicPPL.Experimental.FastLDF,
140+
ldf::DynamicPPL.LogDensityFunction,
141141
stats::NamedTuple=NamedTuple();
142142
include_colon_eq::Bool=true,
143143
include_log_probs::Bool=true,
@@ -152,11 +152,11 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
152152
1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
153153
otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
154154
MCMC iterations).
155-
2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
155+
2. The re-evaluation is faster as it uses `AccumulatorTuple` rather than a full VarInfo.
156156
"""
157157
function ParamsWithStats(
158158
param_vector::AbstractVector,
159-
ldf::DynamicPPL.Experimental.FastLDF,
159+
ldf::DynamicPPL.LogDensityFunction,
160160
stats::NamedTuple=NamedTuple();
161161
include_colon_eq::Bool=true,
162162
include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
174174
else
175175
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
176176
end
177-
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
178-
ldf.model, strategy, AccumulatorTuple(accs)
179-
)
177+
_, vi = DynamicPPL.init!!(ldf.model, AccumulatorTuple(accs), strategy)
180178
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
181179
if include_log_probs
182180
stats = merge(

src/contexts/init.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ This tells you how to obtain values for the random variable `vn` from `p.params`
142142
the last argument is `InitFromParams(params)`, not just `params` itself. Please see the
143143
docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour.
144144
145-
If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case,
145+
If you only use `InitFromParams` with `DynamicPPL.AccumulatorTuple`, as is usually the case,
146146
then you will not need to implement anything else. So far, this is the same as you would do
147147
for creating any new `AbstractInitStrategy` subtype.
148148
@@ -311,12 +311,13 @@ function tilde_assume!!(
311311
# calculation wastes a lot of time going from linked vectorised -> unlinked ->
312312
# linked, and `inv_logjac` will also just be the negative of `fwd_logjac`.
313313
#
314-
# However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which
315-
# case this branch is never hit (since `in_varinfo` will always be false). It does
316-
# mean that the combination of InitFromParams{<:VectorWithRanges} with a full,
317-
# linked, VarInfo will be very slow. That should never really be used, though. So
318-
# (at least for now) we can leave this branch in for full generality with other
319-
# combinations of init strategies / VarInfo.
314+
# However, `VectorWithRanges` is only really used with `AccumulatorTuple`, in which
315+
# case this method is never hit (since there's a special method for it, in
316+
# `src/accs_as_varinfo.jl`). It does mean that the combination of
317+
# InitFromParams{<:VectorWithRanges} with a full, linked, VarInfo will be very slow.
318+
# That should never really be used, though. So (at least for now) we can leave this
319+
# branch in for full generality with other combinations of init strategies /
320+
# VarInfo.
320321
#
321322
# TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue
322323
# is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,

src/fasteval.jl

Lines changed: 5 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ using DynamicPPL:
1010
Model,
1111
ThreadSafeVarInfo,
1212
VarInfo,
13-
OnlyAccsVarInfo,
1413
RangeAndLinked,
1514
VectorWithRanges,
1615
Metadata,
@@ -29,60 +28,6 @@ using LogDensityProblems: LogDensityProblems
2928
import DifferentiationInterface as DI
3029
using Random: Random
3130

32-
"""
33-
DynamicPPL.fast_evaluate!!(
34-
[rng::Random.AbstractRNG,]
35-
model::Model,
36-
strategy::AbstractInitStrategy,
37-
accs::AccumulatorTuple,
38-
)
39-
40-
Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41-
the provided accumulators.
42-
43-
It is assumed that the accumulators passed in have been initialised to appropriate values,
44-
as this function will not reset them. The default constructors for each accumulator will do
45-
this for you correctly.
46-
47-
Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
48-
argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
49-
in the function name.
50-
"""
51-
@inline function fast_evaluate!!(
52-
# Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
53-
# to extra allocations (even for trivial models) and much slower runtime.
54-
rng::Random.AbstractRNG,
55-
model::Model,
56-
strategy::AbstractInitStrategy,
57-
accs::AccumulatorTuple,
58-
)
59-
ctx = InitContext(rng, strategy)
60-
model = DynamicPPL.setleafcontext(model, ctx)
61-
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
62-
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
63-
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
64-
# here.
65-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
66-
# it _should_ do, but this is wrong regardless.
67-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68-
vi = if Threads.nthreads() > 1
69-
param_eltype = DynamicPPL.get_param_eltype(strategy)
70-
accs = map(accs) do acc
71-
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
72-
end
73-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
74-
else
75-
OnlyAccsVarInfo(accs)
76-
end
77-
return DynamicPPL._evaluate!!(model, vi)
78-
end
79-
@inline function fast_evaluate!!(
80-
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
81-
)
82-
# This `@inline` is also mandatory for performance
83-
return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
84-
end
85-
8631
"""
8732
DynamicPPL.LogDensityFunction(
8833
model::Model,
@@ -154,11 +99,11 @@ metadata can often be quite wasteful. In particular, it is very common that the
15499
we care about from model evaluation are those which are stored in accumulators, such as log
155100
probability densities, or `ValuesAsInModel`.
156101
157-
To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
158-
accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
159-
model evaluation.
102+
To avoid this issue, instead of evaluating a model with a full `VarInfo`, we use just an
103+
`AccumulatorTuple`. It implements enough of the `AbstractVarInfo` interface to not error
104+
during model evaluation.
160105
161-
Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
106+
Because `AccumulatorTuple` does not store any parameter values, when evaluating a model with
162107
it, it is mandatory that parameters are provided from outside the VarInfo, namely via
163108
`InitContext`.
164109
@@ -274,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real})
274219
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
275220
)
276221
accs = fast_ldf_accs(f.getlogdensity)
277-
_, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
222+
_, vi = DynamicPPL.init!!(f.model, accs, strategy)
278223
return f.getlogdensity(vi)
279224
end
280225

src/onlyaccs.jl

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)