Skip to content

Commit c49c958

Browse files
committed
add some functionality for ProductDistribution
1 parent 4fd93a8 commit c49c958

File tree

5 files changed

+171
-1
lines changed

5 files changed

+171
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SequentialSamplingModels"
22
uuid = "0e71a2a6-2b30-4447-8742-d083a85e82d1"
33
authors = ["itsdfish"]
4-
version = "0.12.6"
4+
version = "0.12.7"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/SequentialSamplingModels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using PrettyTables
1212
using Random
1313
using SpecialFunctions
1414

15+
using Distributions: ProductDistribution
1516
using HCubature: hcubature
1617
using StatsBase: Weights
1718

@@ -28,6 +29,7 @@ import Distributions: mean
2829
import Distributions: minimum
2930
import Distributions: pdf
3031
import Distributions: rand
32+
import Distributions: rand!
3133
import Distributions: std
3234
import StatsAPI: params
3335
import StatsBase: cor2cov
@@ -109,4 +111,5 @@ include("MDFT.jl")
109111
include("ClassicMDFT.jl")
110112
include("MLBA.jl")
111113
include("ShiftedLogNormal.jl")
114+
include("product_distribution.jl")
112115
end

src/product_distribution.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
function rand(
2+
rng::AbstractRNG,
3+
s::Sampleable{T, R}
4+
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
5+
n = size(s, 2)
6+
data = (; choice = fill(0, n), rt = fill(0.0, n))
7+
return rand!(rng, s, data)
8+
end
9+
10+
function rand(
11+
rng::AbstractRNG,
12+
s::Sampleable{T, R},
13+
dims::Dims
14+
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
15+
n = size(s, 2)
16+
ax = map(Base.OneTo, dims)
17+
data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)]
18+
return [rand!(rng, s, d) for d data]
19+
end
20+
21+
function rand!(
22+
rng::AbstractRNG,
23+
s::Sampleable{T, R},
24+
data::NamedTuple
25+
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
26+
for i 1:size(s, 2)
27+
data.choice[i], data.rt[i] = rand(rng, s.dists[i])
28+
end
29+
return data
30+
end
31+
32+
function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N}
33+
return [logpdf(d, data) for data data_array]
34+
end
35+
36+
function logpdf(d::ProductDistribution, data::NamedTuple)
37+
LL = 0.0
38+
for i 1:length(d.dists)
39+
LL += logpdf(d.dists[i], data.choice[i], data.rt[i])
40+
end
41+
return LL
42+
end

src/type_system.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,8 @@ Increments the evidence states `x` on each time step.
393393
"""
394394
increment!(model::SSM1D, x, μΔ; Δt = 0.001) =
395395
increment!(Random.default_rng(), model, x, μΔ; Δt)
396+
397+
Base.eltype(::Type{<:Sampleable{F, Mixed}}) where {F} =
398+
@NamedTuple{choice::Vector{Int64}, rt::Vector{Float64}}
399+
400+
Base.length(s::ContinuousMultivariateSSM) = 2

test/product_distribution_tests.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
@safetestset "ProductDistribution Tests" begin
2+
@safetestset "rand SSM1D 1" begin
3+
using Distributions
4+
using SequentialSamplingModels
5+
using Test
6+
7+
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
8+
dist = product_distribution(walds)
9+
data = rand(dist)
10+
@test length(data) == 2
11+
@test data[1] < 10
12+
@test data[2] > 10
13+
end
14+
15+
@safetestset "rand SSM1D 2" begin
16+
using Distributions
17+
using SequentialSamplingModels
18+
using Test
19+
20+
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
21+
dist = product_distribution(walds)
22+
data = rand(dist, 3)
23+
@test size(data) == (2, 3)
24+
@test all(data[1, :] .< 10)
25+
@test all(data[2, :] .> 10)
26+
end
27+
28+
@safetestset "rand logpdf 1" begin
29+
using Distributions
30+
using SequentialSamplingModels
31+
using Test
32+
33+
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
34+
dist = product_distribution(walds)
35+
data = rand(dist)
36+
LL1 = logpdf(dist, data)
37+
LL2 = sum(i -> logpdf(walds[i], data[i]), 1:2)
38+
@test LL1 LL2
39+
end
40+
41+
@safetestset "logpdf SSM1D 2" begin
42+
using Distributions
43+
using SequentialSamplingModels
44+
using Test
45+
46+
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
47+
dist = product_distribution(walds)
48+
data = rand(dist, 3)
49+
LL1 = logpdf(dist, data)
50+
LL2 = sum(i -> logpdf(walds[i], data[i, :]), 1:2)
51+
@test LL1 LL2
52+
end
53+
54+
@safetestset "rand SSM2D 1" begin
55+
using Distributions
56+
using SequentialSamplingModels
57+
using Test
58+
59+
lbas = [
60+
LBA(; ν = [3, 2], A = 0.8, k = 0.2, τ = 0.1),
61+
LBA= [1, 2], A = 0.5, k = 0.3, τ = 10)
62+
]
63+
dist = product_distribution(lbas)
64+
data = rand(dist)
65+
@test length(data.rt) == 2
66+
@test data.rt[1] < 10
67+
@test data.rt[2] > 10
68+
end
69+
70+
@safetestset "rand SSM2D 2" begin
71+
using Distributions
72+
using SequentialSamplingModels
73+
using Test
74+
75+
lbas = [
76+
LBA(; ν = [3, 2], A = 0.8, k = 0.2, τ = 0.1),
77+
LBA= [1, 2], A = 0.5, k = 0.3, τ = 10)
78+
]
79+
dist = product_distribution(lbas)
80+
data = rand(dist, 3)
81+
@test length(data) == 3
82+
@test all(map(i -> data[i].rt[1], 1:3) .< 10)
83+
@test all(map(i -> data[i].rt[2], 1:3) .> 10)
84+
end
85+
86+
@safetestset "logpdf SSM2D 1" begin
87+
using Distributions
88+
using SequentialSamplingModels
89+
using Test
90+
91+
lbas = [
92+
LBA(; ν = [3, 2], A = 0.8, k = 0.2, τ = 0.1),
93+
LBA= [1, 2], A = 0.5, k = 0.3, τ = 10)
94+
]
95+
dist = product_distribution(lbas)
96+
data = rand(dist)
97+
98+
LL1 = logpdf(dist, data)
99+
LL2 = sum(i -> logpdf(lbas[i], data.choice[i], data.rt[i]), 1:2)
100+
@test LL1 LL2
101+
end
102+
103+
@safetestset "logpdf SSM2D 2" begin
104+
using Distributions
105+
using SequentialSamplingModels
106+
using Test
107+
108+
lbas = [
109+
LBA(; ν = [3, 2], A = 0.8, k = 0.2, τ = 0.1),
110+
LBA= [1, 2], A = 0.5, k = 0.3, τ = 10)
111+
]
112+
dist = product_distribution(lbas)
113+
data = rand(dist, 3)
114+
115+
LL1 = logpdf(dist, data)
116+
LL2 =
117+
map(j -> sum(i -> logpdf(lbas[i], data[j].choice[i], data[j].rt[i]), 1:2), 1:3)
118+
@test LL1 LL2
119+
end
120+
end

0 commit comments

Comments
 (0)