@@ -921,7 +921,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`.
921921
922922Returns a tuple of the model's return value, plus the updated `varinfo` object.
923923"""
924- @inline function init!! (
924+ function init!! (
925925 # Note that this `@inline` is mandatory for performance, especially for
926926 # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for
927927 # trivial models) and much slower runtime.
@@ -932,30 +932,9 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object.
932932)
933933 ctx = InitContext (rng, strategy)
934934 model = DynamicPPL. setleafcontext (model, ctx)
935- return if _requires_threadsafe (model)
936- # TODO (penelopeysm): The logic for setting eltype of accs is very similar to that
937- # used in `unflatten`. The reason why we need it here is because the VarInfo `vi`
938- # won't have been filled with parameters prior to `init!!` being called.
939- #
940- # Note that this eltype promotion is only needed for threadsafe evaluation. In an
941- # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a
942- # similar method. In other words, it should not be here, and it should not be inside
943- # `unflatten` either. The problem is performance. Shifting this code around can have
944- # massive, inexplicable, impacts on performance. This should be investigated
945- # properly.
946- param_eltype = DynamicPPL. get_param_eltype (strategy)
947- accs = map (vi. accs) do acc
948- DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
949- end
950- vi = DynamicPPL. setaccs!! (vi, accs)
951- tsvi = ThreadSafeVarInfo (resetaccs!! (vi))
952- retval, tsvi_new = DynamicPPL. _evaluate!! (model, tsvi)
953- retval, setaccs!! (tsvi_new. varinfo, DynamicPPL. getaccs (tsvi_new))
954- else
955- DynamicPPL. _evaluate!! (model, resetaccs!! (vi))
956- end
935+ return DynamicPPL. evaluate!! (model, vi)
957936end
958- @inline function init!! (
937+ function init!! (
959938 model:: Model , vi:: AbstractVarInfo , strategy:: AbstractInitStrategy = InitFromPrior ()
960939)
961940 # This `@inline` is also mandatory for performance
@@ -975,6 +954,22 @@ Returns a tuple of the model's return value, plus the updated `varinfo`
975954"""
976955function AbstractPPL. evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
977956 return if _requires_threadsafe (model)
957+ # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
958+ # a gradient type of some AD backend.
959+ # TODO (mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
960+ # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but
961+ # the accumulators in the VarInfo are plain floats, we error since we can't change the
962+ # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
963+ # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
964+ # plain ugly and hacky.
965+ # The below line is finicky for type stability. For instance, assigning the eltype to
966+ # convert to into an intermediate variable makes this unstable (constant propagation)
967+ # fails. Take care when editing.
968+ param_eltype = DynamicPPL. get_param_eltype (varinfo, model. context)
969+ accs = map (DynamicPPL. getaccs (varinfo)) do acc
970+ DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
971+ end
972+ varinfo = DynamicPPL. setaccs!! (varinfo, accs)
978973 wrapper = ThreadSafeVarInfo (resetaccs!! (varinfo))
979974 result, wrapper_new = _evaluate!! (model, wrapper)
980975 # TODO (penelopeysm): If seems that if you pass a TSVI to this method, it
0 commit comments