Skip to content

Commit 284dc88

Browse files
committed
Move convert_eltype code to threadsafe eval function
1 parent f5e2e5a commit 284dc88

File tree

3 files changed

+21
-48
lines changed

3 files changed

+21
-48
lines changed

src/model.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`.
921921
922922
Returns a tuple of the model's return value, plus the updated `varinfo` object.
923923
"""
924-
@inline function init!!(
924+
function init!!(
925925
# Note that this `@inline` is mandatory for performance, especially for
926926
# LogDensityFunction. If it's not inlined, it leads to extra allocations (even for
927927
# trivial models) and much slower runtime.
@@ -932,30 +932,9 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object.
932932
)
933933
ctx = InitContext(rng, strategy)
934934
model = DynamicPPL.setleafcontext(model, ctx)
935-
return if _requires_threadsafe(model)
936-
# TODO(penelopeysm): The logic for setting eltype of accs is very similar to that
937-
# used in `unflatten`. The reason why we need it here is because the VarInfo `vi`
938-
# won't have been filled with parameters prior to `init!!` being called.
939-
#
940-
# Note that this eltype promotion is only needed for threadsafe evaluation. In an
941-
# ideal world, this code should be handled inside `evaluate_threadsafe!!` or a
942-
# similar method. In other words, it should not be here, and it should not be inside
943-
# `unflatten` either. The problem is performance. Shifting this code around can have
944-
# massive, inexplicable, impacts on performance. This should be investigated
945-
# properly.
946-
param_eltype = DynamicPPL.get_param_eltype(strategy)
947-
accs = map(vi.accs) do acc
948-
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
949-
end
950-
vi = DynamicPPL.setaccs!!(vi, accs)
951-
tsvi = ThreadSafeVarInfo(resetaccs!!(vi))
952-
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
953-
retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new))
954-
else
955-
DynamicPPL._evaluate!!(model, resetaccs!!(vi))
956-
end
935+
return DynamicPPL.evaluate!!(model, vi)
957936
end
958-
@inline function init!!(
937+
function init!!(
959938
model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior()
960939
)
961940
# This `@inline` is also mandatory for performance
@@ -975,6 +954,22 @@ Returns a tuple of the model's return value, plus the updated `varinfo`
975954
"""
976955
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
977956
return if _requires_threadsafe(model)
957+
# Use of float_type_with_fallback(param_eltype) is necessary to deal with cases where the
958+
# parameter eltype is a gradient type of some AD backend.
959+
# TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
960+
# for ThreadSafeVarInfo. In that one, if the map produces e.g. a ForwardDiff.Dual, but
961+
# the accumulators in the VarInfo are plain floats, we error since we can't change the
962+
# element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
963+
# messes with cases like using Float32 for logprobs and Float64 for x. Also, this is just
964+
# plain ugly and hacky.
965+
# The below line is finicky for type stability. For instance, assigning the eltype to
966+
# convert to into an intermediate variable makes this unstable (constant propagation
967+
# fails). Take care when editing.
968+
param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context)
969+
accs = map(DynamicPPL.getaccs(varinfo)) do acc
970+
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
971+
end
972+
varinfo = DynamicPPL.setaccs!!(varinfo, accs)
978973
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
979974
result, wrapper_new = _evaluate!!(model, wrapper)
980975
# TODO(penelopeysm): It seems that if you pass a TSVI to this method, it

src/simple_varinfo.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,7 @@ end
278278

279279
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
280280
vals = unflatten(svi.values, x)
281-
# TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
282-
# required but undesireable.
283-
# The below line is finicky for type stability. For instance, assigning the eltype to
284-
# convert to into an intermediate variable makes this unstable (constant propagation)
285-
# fails. Take care when editing.
286-
accs = map(
287-
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi)
288-
)
289-
return SimpleVarInfo(vals, accs, svi.transformation)
281+
return SimpleVarInfo(vals, svi.accs, svi.transformation)
290282
end
291283

292284
function BangBang.empty!!(vi::SimpleVarInfo)

src/varinfo.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,21 +367,7 @@ vector_length(md::Metadata) = sum(length, md.ranges)
367367

368368
function unflatten(vi::VarInfo, x::AbstractVector)
369369
md = unflatten_metadata(vi.metadata, x)
370-
# Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
371-
# a gradient type of some AD backend.
372-
# TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
373-
# for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but
374-
# the accumulators in the VarInfo are plain floats, we error since we can't change the
375-
# element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
376-
# messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
377-
# plain ugly and hacky.
378-
# The below line is finicky for type stability. For instance, assigning the eltype to
379-
# convert to into an intermediate variable makes this unstable (constant propagation)
380-
# fails. Take care when editing.
381-
accs = map(
382-
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
383-
)
384-
return VarInfo(md, accs)
370+
return VarInfo(md, vi.accs)
385371
end
386372

387373
# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in

0 commit comments

Comments
 (0)