Skip to content

Commit c8208de

Browse files
authored
Merge pull request #99 from itsdfish/flexible_sigma
make sigma flexible
2 parents 003199a + 6681754 commit c8208de

File tree

9 files changed

+118
-48
lines changed

9 files changed

+118
-48
lines changed

src/LBA.jl

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""
2-
LBA{T<:Real} <: AbstractLBA
2+
LBA{T <: Real, T1 <: Union{<: T, Vector{<: T}}} <: AbstractLBA
33
44
A model object for the linear ballistic accumulator.
55
66
# Parameters
77
88
- `ν::Vector{T}`: a vector of drift rates
9-
- `σ::Vector{T}`: a vector of drift rate standard deviation
9+
- `σ::T1`: a scalar or vector of drift rate standard deviation
1010
- `A::T`: max start point
1111
- `k::T`: A + k = b, where b is the decision threshold
1212
- `τ::T`: an encoding-response offset
@@ -19,7 +19,7 @@ Two constructors are defined below. The first constructor uses positional argume
1919
2020
The second constructor uses keywords with default values, and is not order dependent:
2121
22-
LBA(;τ=.3, A=.8, k=.5, ν=[2.0,1.75], σ=[1.0,1.0])
22+
LBA(;τ = .3, A = .8, k = .5, ν = [2.0,1.75], σ = 1)
2323
2424
# Example
2525
@@ -35,22 +35,22 @@ loglike = logpdf.(dist, choice, rt)
3535
3636
Brown, S. D., & Heathcote, A. (2008). The simplest complete model of choice response time: Linear ballistic accumulation. Cognitive psychology, 57(3), 153-178.
3737
"""
38-
mutable struct LBA{T <: Real} <: AbstractLBA
38+
mutable struct LBA{T <: Real, T1 <: Union{<:T, Vector{<:T}}} <: AbstractLBA{T, T1}
3939
ν::Vector{T}
40-
σ::Vector{T}
40+
σ::T1
4141
A::T
4242
k::T
4343
τ::T
4444
end
4545

46-
function LBA(ν, σ, A, k, τ)
46+
function LBA(ν, σ, A, k, τ::T) where {T}
4747
_, _, A, k, τ = promote(ν[1], σ[1], A, k, τ)
48-
ν = convert(Vector{typeof(k)}, ν)
49-
σ = convert(Vector{typeof(k)}, σ)
48+
ν = convert(Vector{T}, ν)
49+
σ = isa(σ, Vector) ? convert(Vector{T}, σ) : convert(T, σ)
5050
return LBA(ν, σ, A, k, τ)
5151
end
5252

53-
LBA(; τ = 0.3, A = 0.8, k = 0.5, ν = [2.0, 1.75], σ = fill(1.0, length(ν))) =
53+
LBA(; τ = 0.3, A = 0.8, k = 0.5, ν = [2.0, 1.75], σ = 1) =
5454
LBA(ν, σ, A, k, τ)
5555

5656
function params(d::LBA)
@@ -79,7 +79,7 @@ function sample_drift_rates(rng::AbstractRNG, ν, σ)
7979
v = similar(ν)
8080
n_options = length(ν)
8181
while negative
82-
v = [rand(rng, Normal[i], σ[i])) for i 1:n_options]
82+
v = @. rand(rng, Normal(ν, σ))
8383
negative = any(x -> x > 0, v) ? false : true
8484
end
8585
return v
@@ -97,7 +97,7 @@ function rand(rng::AbstractRNG, d::AbstractLBA)
9797
return (; choice, rt)
9898
end
9999

100-
function logpdf(d::AbstractLBA, c, rt)
100+
function logpdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Vector{<:Real}}
101101
(; τ, A, k, ν, σ) = d
102102
b = A + k
103103
LL = 0.0
@@ -114,7 +114,24 @@ function logpdf(d::AbstractLBA, c, rt)
114114
return max(LL, -1000.0)
115115
end
116116

117-
function pdf(d::AbstractLBA, c, rt)
117+
function logpdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Real}
118+
(; τ, A, k, ν, σ) = d
119+
b = A + k
120+
LL = 0.0
121+
rt < τ ? (return -Inf) : nothing
122+
for i 1:length(ν)
123+
if c == i
124+
LL += log_dens(d, ν[i], σ, rt)
125+
else
126+
LL += log(max(0.0, 1 - cummulative(d, ν[i], σ, rt)))
127+
end
128+
end
129+
pneg = pnegative(d)
130+
LL = LL - log(1 - pneg)
131+
return max(LL, -1000.0)
132+
end
133+
134+
function pdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Vector{<:Real}}
118135
(; τ, A, k, ν, σ) = d
119136
b = A + k
120137
den = 1.0
@@ -132,6 +149,24 @@ function pdf(d::AbstractLBA, c, rt)
132149
isnan(den) ? (return 0.0) : (return den)
133150
end
134151

152+
function pdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Real}
153+
(; τ, A, k, ν, σ) = d
154+
b = A + k
155+
den = 1.0
156+
rt < τ ? (return 1e-10) : nothing
157+
for i 1:length(ν)
158+
if c == i
159+
den *= dens(d, ν[i], σ, rt)
160+
else
161+
den *= (1 - cummulative(d, ν[i], σ, rt))
162+
end
163+
end
164+
pneg = pnegative(d)
165+
den = den / (1 - pneg)
166+
den = max(den, 1e-10)
167+
isnan(den) ? (return 0.0) : (return den)
168+
end
169+
135170
function dens(d::AbstractLBA, v, σ, rt)
136171
(; τ, A, k) = d
137172
dt = rt - τ
@@ -164,7 +199,7 @@ function cummulative(d::AbstractLBA, v, σ, rt)
164199
return cm
165200
end
166201

167-
function pnegative(d::AbstractLBA)
202+
function pnegative(d::AbstractLBA{T, T1}) where {T, T1 <: Vector{<:Real}}
168203
(; ν, σ) = d
169204
p = 1.0
170205
for i 1:length(ν)
@@ -173,6 +208,15 @@ function pnegative(d::AbstractLBA)
173208
return p
174209
end
175210

211+
function pnegative(d::AbstractLBA{T, T1}) where {T, T1 <: Real}
212+
(; ν, σ) = d
213+
p = 1.0
214+
for i 1:length(ν)
215+
p *= Φ(-ν[i] / σ)
216+
end
217+
return p
218+
end
219+
176220
"""
177221
simulate(model::AbstractLBA; n_steps=100, _...)
178222

src/LCA.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ mutable struct LCA{T <: Real} <: AbstractLCA
4949
τ::T
5050
end
5151

52-
function LCA(ν, σ, β, λ, α, τ)
52+
function LCA(ν, σ, β, λ, α, τ::T) where {T}
5353
_, σ, β, λ, α, τ = promote(ν[1], σ, β, λ, α, τ)
54-
ν = convert(Vector{typeof(τ)}, ν)
54+
ν = convert(Vector{T}, ν)
5555
return LCA(ν, σ, β, λ, α, τ)
5656
end
5757

src/LNR.jl

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Parameters
55
66
- `ν`: a vector of means in log-space
7-
- `σ`: a vector of standard deviation parameter in log-space
7+
- `σ`: a scalar or vector of standard deviation parameter in log-space
88
- `τ`: a encoding-response offset
99
1010
# Constructors
@@ -15,13 +15,13 @@ Two constructors are defined below. The first constructor uses positional argume
1515
1616
The second constructor uses keywords with default values, and is not order dependent:
1717
18-
LNR(; ν = [-1, -2], σ = fill(1.0, length(ν)), τ = 0.20)
18+
LNR(; ν = [-1, -2], σ = 1, τ = 0.20)
1919
2020
# Example
2121
2222
```julia
2323
using SequentialSamplingModels
24-
dist = LNR(ν=[-2,-3], σ=[1.0,1.0], τ=.3)
24+
dist = LNR(ν = [-2,-3], σ = 1, τ = .3)
2525
choice,rt = rand(dist, 10)
2626
like = pdf.(dist, choice, rt)
2727
loglike = logpdf.(dist, choice, rt)
@@ -32,24 +32,24 @@ Rouder, J. N., Province, J. M., Morey, R. D., Gomez, P., & Heathcote, A. (2015).
3232
The lognormal race: A cognitive-process model of choice and latency with desirable
3333
psychometric properties. Psychometrika, 80(2), 491-513.
3434
"""
35-
struct LNR{T <: Real} <: AbstractLNR
35+
struct LNR{T <: Real, T1 <: Union{<:T, Vector{<:T}}} <: AbstractLNR{T, T1}
3636
ν::Vector{T}
37-
σ::Vector{T}
37+
σ::T1
3838
τ::T
3939
end
4040

41-
function LNR(ν, σ, τ)
41+
function LNR(ν, σ, τ::T) where {T}
4242
_, _, τ = promote(ν[1], σ[1], τ)
43-
ν = convert(Vector{typeof(τ)}, ν)
44-
σ = convert(Vector{typeof(τ)}, σ)
43+
ν = convert(Vector{T}, ν)
44+
σ = isa(σ, Vector) ? convert(Vector{T}, σ) : convert(T, σ)
4545
return LNR(ν, σ, τ)
4646
end
4747

4848
function params(d::AbstractLNR)
4949
return (d.ν, d.σ, d.τ)
5050
end
5151

52-
LNR(; ν = [-1, -2], σ = fill(1.0, length(ν)), τ = 0.20) = LNR(ν, σ, τ)
52+
LNR(; ν = [-1, -2], σ = 1, τ = 0.20) = LNR(ν, σ, τ)
5353

5454
function rand(rng::AbstractRNG, dist::AbstractLNR)
5555
(; ν, σ, τ) = dist
@@ -58,7 +58,7 @@ function rand(rng::AbstractRNG, dist::AbstractLNR)
5858
return (; choice, rt)
5959
end
6060

61-
function logpdf(d::AbstractLNR, r::Int, t::Float64)
61+
function logpdf(d::AbstractLNR{T, T1}, r::Int, t::Float64) where {T, T1 <: Vector{<:Real}}
6262
(; ν, σ, τ) = d
6363
LL = 0.0
6464
for i 1:length(ν)
@@ -71,7 +71,20 @@ function logpdf(d::AbstractLNR, r::Int, t::Float64)
7171
return LL
7272
end
7373

74-
function pdf(d::AbstractLNR, r::Int, t::Float64)
74+
function logpdf(d::AbstractLNR{T, T1}, r::Int, t::Float64) where {T, T1 <: Real}
75+
(; ν, σ, τ) = d
76+
LL = 0.0
77+
for i 1:length(ν)
78+
if i == r
79+
LL += logpdf(LogNormal(ν[i], σ), t - τ)
80+
else
81+
LL += logccdf(LogNormal(ν[i], σ), t - τ)
82+
end
83+
end
84+
return LL
85+
end
86+
87+
function pdf(d::AbstractLNR{T, T1}, r::Int, t::Float64) where {T, T1 <: Vector{<:Real}}
7588
(; ν, σ, τ) = d
7689
density = 1.0
7790
for i 1:length(ν)
@@ -83,3 +96,16 @@ function pdf(d::AbstractLNR, r::Int, t::Float64)
8396
end
8497
return density
8598
end
99+
100+
function pdf(d::AbstractLNR{T, T1}, r::Int, t::Float64) where {T, T1 <: Real}
101+
(; ν, σ, τ) = d
102+
density = 1.0
103+
for i 1:length(ν)
104+
if i == r
105+
density *= pdf(LogNormal(ν[i], σ), t - τ)
106+
else
107+
density *= (1 - cdf(LogNormal(ν[i], σ), t - τ))
108+
end
109+
end
110+
return density
111+
end

src/MDFT.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ mutable struct MDFT{T <: Real} <: AbstractMDFT
9090
_att_idx::Int
9191
end
9292

93-
function MDFT(σ, α, τ, γ, κ, ϕ1, ϕ2, β, C)
93+
function MDFT(σ, α, τ::T, γ, κ, ϕ1, ϕ2, β, C) where {T}
9494
σ, α, τ, γ, _, ϕ1, ϕ2, β, = promote(σ, α, τ, γ, κ[1], ϕ1, ϕ2, β)
95-
κ = convert(Vector{typeof(τ)}, κ)
96-
C = convert(Array{typeof(τ), 2}, C)
95+
κ = convert(Vector{T}, κ)
96+
C = convert(Array{T, 2}, C)
9797
_CM = zeros(size(C, 1), length(κ))
9898
S = similar(C)
9999
return MDFT(σ, α, τ, γ, κ, ϕ1, ϕ2, β, S, C, _CM, 0)

src/MLBA.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
MLBA{T <: Real} <: AbstractMLBA
2+
MLBA{T <: Real, T1 <: Union{<: T, Vector{<: T}}} <: AbstractMLBA{T,T1}
33
44
# Fields
55
@@ -8,7 +8,7 @@
88
- `λₚ::T`: decay constant for attention weights of positive differences
99
- `λₙ::T`: decay constant for attention weights of negative differences
1010
- `γ::T`: risk aversion exponent for subjective values
11-
- `σ::Vector{T}`: a vector of drift rate standard deviation
11+
- `σ::T1`: a scalar or vector of drift rate standard deviation
1212
- `A::T`: max start point
1313
- `k::T`: A + k = b, where b is the decision threshold
1414
- `τ::T`: an encoding-response offset
@@ -27,7 +27,7 @@
2727
τ = 0.3,
2828
A = 1.0,
2929
k = 1.0,
30-
σ = fill(1.0, n_alternatives)
30+
σ = 1.0
3131
)
3232
3333
# Example
@@ -59,22 +59,22 @@ loglike = logpdf.(dist, choice, rt, (M,))
5959
6060
Trueblood, J. S., Brown, S. D., & Heathcote, A. (2014). The multiattribute linear ballistic accumulator model of context effects in multialternative choice. Psychological Review, 121(2), 179.
6161
"""
62-
mutable struct MLBA{T <: Real} <: AbstractMLBA
62+
mutable struct MLBA{T <: Real, T1 <: Union{<:T, Vector{<:T}}} <: AbstractMLBA{T, T1}
6363
ν::Vector{T}
6464
β₀::T
6565
λₚ::T
6666
λₙ::T
6767
γ::T
68-
σ::Vector{T}
68+
σ::T1
6969
A::T
7070
k::T
7171
τ::T
7272
end
7373

74-
function MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)
74+
function MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k::T, τ) where {T}
7575
_, β₀, λₚ, λₙ, γ, _, A, k, τ = promote(ν[1], β₀, λₚ, λₙ, γ, σ[1], A, k, τ)
76-
ν = convert(Vector{typeof(k)}, ν)
77-
σ = convert(Vector{typeof(k)}, σ)
76+
ν = convert(Vector{T}, ν)
77+
σ = isa(σ, Vector) ? convert(Vector{T}, σ) : convert(T, σ)
7878
return MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)
7979
end
8080

@@ -88,7 +88,7 @@ MLBA(;
8888
τ = 0.3,
8989
A = 1.0,
9090
k = 1.0,
91-
σ = fill(1.0, n_alternatives)
91+
σ = 1.0
9292
) =
9393
MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)
9494

src/RDM.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ struct RDM{T <: Real} <: AbstractRDM
121121
τ::T
122122
end
123123

124-
function RDM(ν, k, A, τ)
124+
function RDM(ν, k::T, A, τ) where {T}
125125
_, A, k, τ = promote(ν[1], k, A, τ)
126-
ν = convert(Vector{typeof(k)}, ν)
126+
ν = convert(Vector{T}, ν)
127127
return RDM(ν, A, k, τ)
128128
end
129129

src/poisson_race.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ struct PoissonRace{T <: Real} <: AbstractPoissonRace
3636
τ::T
3737
end
3838

39-
function PoissonRace(ν, α, τ)
39+
function PoissonRace(ν, α, τ::T) where {T}
4040
_, τ = promote(ν[1], τ)
41-
ν = convert(Vector{typeof(τ)}, ν)
41+
ν = convert(Vector{T}, ν)
4242
return PoissonRace(ν, α, τ)
4343
end
4444

src/stDDM.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ mutable struct stDDM{T <: Real} <: AbstractstDDM
6464
τ::T
6565
end
6666

67-
function stDDM(ν, σ, s, z, η, ρ, α, τ)
67+
function stDDM(ν, σ, s, z, η, ρ, α, τ::T) where {T}
6868
_, σ, s, z, _, ρ, α, τ = promote(ν[1], σ, s, z, η[1], ρ, α, τ)
69-
ν = convert(Vector{typeof(τ)}, ν)
70-
η = convert(Vector{typeof(τ)}, η)
69+
ν = convert(Vector{T}, ν)
70+
η = convert(Vector{T}, η)
7171
return stDDM(ν, σ, s, z, η, ρ, α, τ)
7272
end
7373

0 commit comments

Comments
 (0)