Skip to content

Commit a0331cd

Browse files
committed
Add miscellaneous LDP tests
1 parent 47ed385 commit a0331cd

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/fasteval.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,45 @@ using Mooncake: Mooncake
6969
end
7070
end
7171

72+
@testset "LogDensityFunction: interface" begin
73+
# miscellaneous parts of the LogDensityProblems interface
74+
@testset "dimensions" begin
75+
@model function m1()
76+
x ~ Normal()
77+
y ~ Normal()
78+
return nothing
79+
end
80+
model = m1()
81+
ldf = DynamicPPL.LogDensityFunction(model)
82+
@test LogDensityProblems.dimension(ldf) == 2
83+
84+
@model function m2()
85+
x ~ Dirichlet(ones(4))
86+
y ~ Categorical(x)
87+
return nothing
88+
end
89+
model = m2()
90+
ldf = DynamicPPL.LogDensityFunction(model)
91+
@test LogDensityProblems.dimension(ldf) == 5
92+
linked_vi = DynamicPPL.link!!(VarInfo(model), model)
93+
ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi)
94+
@test LogDensityProblems.dimension(ldf) == 4
95+
end
96+
97+
@testset "capabilities" begin
98+
@model f() = x ~ Normal()
99+
model = f()
100+
# No adtype
101+
ldf = DynamicPPL.LogDensityFunction(model)
102+
@test LogDensityProblems.capabilities(typeof(ldf)) ==
103+
LogDensityProblems.LogDensityOrder{0}()
104+
# With adtype
105+
ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
106+
@test LogDensityProblems.capabilities(typeof(ldf)) ==
107+
LogDensityProblems.LogDensityOrder{1}()
108+
end
109+
end
110+
72111
@testset "LogDensityFunction: performance" begin
73112
if Threads.nthreads() == 1
74113
# Evaluating these three models should not lead to any allocations (but only when

0 commit comments

Comments
 (0)