Skip to content

Commit 1f1bb01

Browse files
committed
Fix performance issues
1 parent 0688f11 commit 1f1bb01

File tree

3 files changed

+15
-20
lines changed

3 files changed

+15
-20
lines changed

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ function build_output(modeldef, linenumbernode)
724724
# to the call site
725725
modeldef[:body] = MacroTools.@q begin
726726
$(linenumbernode)
727-
return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
727+
return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...))
728728
end
729729

730730
return MacroTools.@q begin

src/model.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ struct Model{
4646
context::Ctx
4747

4848
@doc """
49-
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
49+
Model{Threaded,missings}(f, args::NamedTuple, defaults::NamedTuple)
5050
5151
Create a model with evaluation function `f` and missing arguments overwritten by
5252
`missings`.
5353
"""
54-
function Model{missings,Threaded}(
54+
function Model{Threaded,missings}(
5555
f::F,
5656
args::NamedTuple{argnames,Targs},
5757
defaults::NamedTuple{defaultnames,Tdefaults},
@@ -71,32 +71,27 @@ Create a model with evaluation function `f` and missing arguments deduced from `
7171
Default arguments `defaults` are used internally when constructing instances of the same
7272
model with different arguments.
7373
"""
74-
@generated function Model(
74+
@generated function Model{Threaded}(
7575
f::F,
7676
args::NamedTuple{argnames,Targs},
7777
defaults::NamedTuple{kwargnames,Tkwargs},
7878
context::AbstractContext=DefaultContext(),
79-
threadsafe::Bool=false,
80-
) where {F,argnames,Targs,kwargnames,Tkwargs}
79+
) where {Threaded,F,argnames,Targs,kwargnames,Tkwargs}
8180
missing_args = Tuple(
8281
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
8382
)
8483
missing_kwargs = Tuple(
8584
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
8685
)
87-
return :(Model{$(missing_args..., missing_kwargs...),threadsafe}(
86+
return :(Model{Threaded,$(missing_args..., missing_kwargs...)}(
8887
f, args, defaults, context
8988
))
9089
end
9190

92-
function Model(
93-
f,
94-
args::NamedTuple,
95-
context::AbstractContext=DefaultContext(),
96-
threadsafe=false;
97-
kwargs...,
98-
)
99-
return Model(f, args, NamedTuple(kwargs), context, threadsafe)
91+
function Model{Threaded}(
92+
f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...
93+
) where {Threaded}
94+
return Model{Threaded}(f, args, NamedTuple(kwargs), context)
10095
end
10196

10297
function _requires_threadsafe(
@@ -112,7 +107,7 @@ Return a new `Model` with the same evaluation function and other arguments, but
112107
with its underlying context set to `context`.
113108
"""
114109
function contextualize(model::Model, context::AbstractContext)
115-
return Model(model.f, model.args, model.defaults, context, _requires_threadsafe(model))
110+
return Model{_requires_threadsafe(model)}(model.f, model.args, model.defaults, context)
116111
end
117112

118113
"""
@@ -148,7 +143,7 @@ function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M}
148143
return if _requires_threadsafe(model) == threadsafe
149144
model
150145
else
151-
Model{M,threadsafe}(model.f, model.args, model.defaults, model.context)
146+
Model{threadsafe,M}(model.f, model.args, model.defaults, model.context)
152147
end
153148
end
154149

@@ -955,9 +950,9 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object.
955950
vi = DynamicPPL.setaccs!!(vi, accs)
956951
tsvi = ThreadSafeVarInfo(resetaccs!!(vi))
957952
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
958-
return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new))
953+
retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new))
959954
else
960-
return DynamicPPL._evaluate!!(model, resetaccs!!(vi))
955+
DynamicPPL._evaluate!!(model, resetaccs!!(vi))
961956
end
962957
end
963958
@inline function init!!(

test/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
vi = VarInfo(model)
148148
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)
149149
x = vi[:]
150-
bench = median(@be LogDensityProblems.logdensity(ldf, x))
150+
bench = median(@be LogDensityProblems.logdensity($ldf, $x))
151151
@test iszero(bench.allocs)
152152
end
153153
end

0 commit comments

Comments
 (0)