11"""
2- LBA{T<: Real} <: AbstractLBA
2+ LBA{T <: Real, T1 <: Union{<: T, Vector{<: T}} } <: AbstractLBA
33
44A model object for the linear ballistic accumulator.
55
66# Parameters
77
88- `ν::Vector{T}`: a vector of drift rates
9- - `σ::Vector{T} `: a vector of drift rate standard deviation
9+ - `σ::T1 `: a scalar or vector of drift rate standard deviation
1010- `A::T`: max start point
1111- `k::T`: A + k = b, where b is the decision threshold
1212- `τ::T`: an encoding-response offset
@@ -19,7 +19,7 @@ Two constructors are defined below. The first constructor uses positional argume
1919
2020The second constructor uses keywords with default values, and is not order dependent:
2121
22- LBA(;τ= .3, A= .8, k= .5, ν= [2.0,1.75], σ=[1.0,1.0] )
22+ LBA(;τ = .3, A = .8, k = .5, ν = [2.0,1.75], σ = 1 )
2323
2424# Example
2525
@@ -35,22 +35,22 @@ loglike = logpdf.(dist, choice, rt)
3535
3636Brown, S. D., & Heathcote, A. (2008). The simplest complete model of choice response time: Linear ballistic accumulation. Cognitive psychology, 57(3), 153-178.
3737"""
38- mutable struct LBA{T <: Real } <: AbstractLBA
38+ mutable struct LBA{T <: Real , T1 <: Union{<:T, Vector{<:T}} } <: AbstractLBA{T, T1}
3939 ν:: Vector{T}
40- σ:: Vector{T}
40+ σ:: T1
4141 A:: T
4242 k:: T
4343 τ:: T
4444end
4545
46- function LBA (ν, σ, A, k, τ)
46+ function LBA (ν, σ, A, k, τ:: T ) where {T}
4747 _, _, A, k, τ = promote (ν[1 ], σ[1 ], A, k, τ)
48- ν = convert (Vector{typeof (k) }, ν)
49- σ = convert (Vector{typeof (k)} , σ)
48+ ν = convert (Vector{T }, ν)
49+ σ = isa (σ, Vector) ? convert (Vector{T}, σ) : convert (T , σ)
5050 return LBA (ν, σ, A, k, τ)
5151end
5252
53- LBA (; τ = 0.3 , A = 0.8 , k = 0.5 , ν = [2.0 , 1.75 ], σ = fill ( 1.0 , length (ν)) ) =
53+ LBA (; τ = 0.3 , A = 0.8 , k = 0.5 , ν = [2.0 , 1.75 ], σ = 1 ) =
5454 LBA (ν, σ, A, k, τ)
5555
5656function params (d:: LBA )
@@ -79,7 +79,7 @@ function sample_drift_rates(rng::AbstractRNG, ν, σ)
7979 v = similar (ν)
8080 n_options = length (ν)
8181 while negative
82- v = [ rand (rng, Normal (ν[i] , σ[i])) for i ∈ 1 : n_options]
82+ v = @. rand (rng, Normal (ν, σ))
8383 negative = any (x -> x > 0 , v) ? false : true
8484 end
8585 return v
@@ -97,7 +97,7 @@ function rand(rng::AbstractRNG, d::AbstractLBA)
9797 return (; choice, rt)
9898end
9999
100- function logpdf (d:: AbstractLBA , c, rt)
100+ function logpdf (d:: AbstractLBA{T, T1} , c, rt) where {T, T1 <: Vector{<:Real} }
101101 (; τ, A, k, ν, σ) = d
102102 b = A + k
103103 LL = 0.0
@@ -114,7 +114,24 @@ function logpdf(d::AbstractLBA, c, rt)
114114 return max (LL, - 1000.0 )
115115end
116116
117- function pdf (d:: AbstractLBA , c, rt)
117+ function logpdf (d:: AbstractLBA{T, T1} , c, rt) where {T, T1 <: Real }
118+ (; τ, A, k, ν, σ) = d
119+ b = A + k
120+ LL = 0.0
121+ rt < τ ? (return - Inf ) : nothing
122+ for i ∈ 1 : length (ν)
123+ if c == i
124+ LL += log_dens (d, ν[i], σ, rt)
125+ else
126+ LL += log (max (0.0 , 1 - cummulative (d, ν[i], σ, rt)))
127+ end
128+ end
129+ pneg = pnegative (d)
130+ LL = LL - log (1 - pneg)
131+ return max (LL, - 1000.0 )
132+ end
133+
134+ function pdf (d:: AbstractLBA{T, T1} , c, rt) where {T, T1 <: Vector{<:Real} }
118135 (; τ, A, k, ν, σ) = d
119136 b = A + k
120137 den = 1.0
@@ -132,6 +149,24 @@ function pdf(d::AbstractLBA, c, rt)
132149 isnan (den) ? (return 0.0 ) : (return den)
133150end
134151
152+ function pdf (d:: AbstractLBA{T, T1} , c, rt) where {T, T1 <: Real }
153+ (; τ, A, k, ν, σ) = d
154+ b = A + k
155+ den = 1.0
156+ rt < τ ? (return 1e-10 ) : nothing
157+ for i ∈ 1 : length (ν)
158+ if c == i
159+ den *= dens (d, ν[i], σ, rt)
160+ else
161+ den *= (1 - cummulative (d, ν[i], σ, rt))
162+ end
163+ end
164+ pneg = pnegative (d)
165+ den = den / (1 - pneg)
166+ den = max (den, 1e-10 )
167+ isnan (den) ? (return 0.0 ) : (return den)
168+ end
169+
135170function dens (d:: AbstractLBA , v, σ, rt)
136171 (; τ, A, k) = d
137172 dt = rt - τ
@@ -164,7 +199,7 @@ function cummulative(d::AbstractLBA, v, σ, rt)
164199 return cm
165200end
166201
167- function pnegative (d:: AbstractLBA )
202+ function pnegative (d:: AbstractLBA{T, T1} ) where {T, T1 <: Vector{<:Real} }
168203 (; ν, σ) = d
169204 p = 1.0
170205 for i ∈ 1 : length (ν)
@@ -173,6 +208,15 @@ function pnegative(d::AbstractLBA)
173208 return p
174209end
175210
211+ function pnegative (d:: AbstractLBA{T, T1} ) where {T, T1 <: Real }
212+ (; ν, σ) = d
213+ p = 1.0
214+ for i ∈ 1 : length (ν)
215+ p *= Φ (- ν[i] / σ)
216+ end
217+ return p
218+ end
219+
176220"""
177221 simulate(model::AbstractLBA; n_steps=100, _...)
178222
0 commit comments