Skip to content

Commit ad998c1

Browse files
committed
add LNR model plot
1 parent cd6311e commit ad998c1

File tree

5 files changed

+98
-3
lines changed

5 files changed

+98
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SequentialSamplingModels"
22
uuid = "0e71a2a6-2b30-4447-8742-d083a85e82d1"
33
authors = ["itsdfish"]
4-
version = "0.11.10"
4+
version = "0.11.11"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

docs/src/bayes_factor.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ mll_rdm = stepping_stone(pt_rdm)
131131
```
132132

133133
## Compute the Bayes Factor
134-
The bayes factor is obtained by exponentiating the difference between marginal log likelihoods. The value of `1.21` indicates that the LBA is `1.21` times more likely to have generated the data.
134+
The bayes factor is obtained by exponentiating the difference between marginal log likelihoods. The value of `1.21` indicates that the data are `1.21` times more likely under the LBA than the RDM.
135135
```julia
136136
bf = exp(mll_lba - mll_rdm)
137137
```

ext/plots/plot_model.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119
compute_threshold(model) = model.α
120120
compute_threshold(model::AbstractLBA) = model.A + model.k
121121
compute_threshold(model::AbstractRDM) = model.A + model.k
122-
122+
compute_threshold(model::AbstractLNR) = 1.0
123123
"""
124124
get_default_labels(model::AbstractRDM)
125125
@@ -153,6 +153,23 @@ function get_default_labels(model::AbstractLBA)
153153
]
154154
end
155155

156+
"""
157+
get_default_labels(model::AbstractLNR)
158+
159+
Generates default parameter labels and locations
160+
161+
# Arguments
162+
163+
- `model::AbstractLBA`: an object for the log normal race model
164+
"""
165+
function get_default_labels(model::AbstractLNR)
166+
α = 1
167+
return [
168+
(0, α, text("α", 10, :right)),
169+
(model.τ, 0, text("τ", 10, :bottom))
170+
]
171+
end
172+
156173
"""
157174
get_default_labels(model::AbstractRDM)
158175
@@ -483,6 +500,34 @@ function get_model_plot_defaults(d::AbstractLBA)
483500
)
484501
end
485502

503+
"""
504+
get_model_plot_defaults(d::AbstractLNR)
505+
506+
Returns default plot options
507+
508+
# Arguments
509+
510+
- `d::AbstractLNR`: an object for the log normal race
511+
- `n_subplots`: the number of subplots (i.e., choices)
512+
"""
513+
function get_model_plot_defaults(d::AbstractLNR)
514+
n_subplots = n_options(d)
515+
title = ["choice $i" for _ 1:1, i 1:n_subplots]
516+
return (
517+
xaxis = nothing,
518+
yaxis = nothing,
519+
xticks = nothing,
520+
yticks = nothing,
521+
grid = false,
522+
linewidth = 0.75,
523+
color = :black,
524+
leg = false,
525+
title,
526+
layout = (n_subplots, 1),
527+
arrow = :closed
528+
)
529+
end
530+
486531
"""
487532
get_model_plot_defaults(d::AbstractCDDM)
488533

src/LNR.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,29 @@ function pdf(d::AbstractLNR{T, T1}, r::Int, t::Float64) where {T, T1 <: Real}
109109
end
110110
return density
111111
end
112+
113+
"""
114+
simulate(model::AbstractLNR; n_steps=100, _...)
115+
116+
Returns a matrix containing evidence samples of the LBA decision process. In the matrix, rows
117+
represent samples of evidence per time step and columns represent different accumulators.
118+
119+
# Arguments
120+
121+
- `model::AbstractLNR`: a subtype of AbstractLNR
122+
123+
# Keywords
124+
125+
- `n_steps=100`: number of time steps at which evidence is recorded
126+
"""
127+
function simulate(rng::AbstractRNG, model::AbstractLNR; n_steps = 100, _...)
128+
(; τ, ν, σ) = model
129+
n = length(ν)
130+
νs = @. rand(rng, Normal(ν, σ))
131+
βs = @. exp(νs)
132+
_,choice = findmax(βs)
133+
t = 1 / βs[choice]
134+
evidence = collect.(range.(0, βs * t, length = 100))
135+
time_steps = range(0, t, length = n_steps)
136+
return time_steps, hcat(evidence...)
137+
end

test/plots.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,30 @@
242242
plot_model(dist; add_density = true, n_sim = 2, density_kwargs, xlims = (0, 1.2))
243243
end
244244

245+
@safetestset "LNR" begin
246+
using Plots
247+
using SequentialSamplingModels
248+
using Test
249+
250+
dist = LNR()
251+
252+
h = histogram(dist)
253+
plot!(h, dist)
254+
255+
histogram(dist)
256+
plot!(dist)
257+
258+
p = plot(dist)
259+
histogram!(p, dist)
260+
261+
plot(dist)
262+
histogram!(dist)
263+
264+
density_kwargs = (; t_range = range(0.1, 1.2, length = 100),)
265+
plot_model(dist; add_density = true, n_sim = 2, density_kwargs, xlims = (0, 1.2))
266+
end
267+
268+
245269
@safetestset "WaldMixture" begin
246270
using Plots
247271
using SequentialSamplingModels

0 commit comments

Comments
 (0)