Skip to content

Commit f378aad

Browse files
authored
Improve weighted reservoir with replacement algorithm (#124)
1 parent 189da4b commit f378aad

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

src/SamplingUtils.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@ struct SeqIterWRSampler{R}
1919
n::Int
2020
end
2121

22-
@inline function Base.iterate(s::SeqIterWRSampler)
23-
curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n
24-
return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax))
25-
end
26-
@inline function Base.iterate(s::SeqIterWRSampler, state)
22+
@inline function Base.iterate(s::SeqIterWRSampler, state = (s.n, -log(Float64(s.N))))
2723
state[1] == 0 && return nothing
2824
curmax = state[2] + randexp(s.rng)/state[1]
2925
return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax))
@@ -159,4 +155,4 @@ function ordmemory(n)
159155
ord = Memory{Int}(undef, n)
160156
for i in eachindex(ord) ord[i] = i end
161157
ord
162-
end
158+
end

src/UnweightedSamplingMulti.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,4 +278,3 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler)
278278
end
279279
end
280280

281-

src/WeightedSamplingMulti.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,18 @@ end
128128
s = @inline update_state!(s, w)
129129
if s.seen_k <= n
130130
@inbounds s.value[s.seen_k] = el
131-
@inbounds s.weights[s.seen_k] = w
131+
@inbounds s.weights[s.seen_k] = s.state
132132
if s.seen_k == n
133-
s.value .= sample(s.rng, s.value, Weights(s.weights, s.state), n;
134-
ordered = is_ordered(s))
133+
j, curx = 1, 0.0
134+
newvalues = similar(s.value)
135+
@inbounds for i in n:-1:1
136+
curx += (1-exp(-randexp(s.rng)/i))*(1-curx)
137+
while s.weights[j] < curx * s.state
138+
j += 1
139+
end
140+
newvalues[i] = s.value[j]
141+
end
142+
s.value .= newvalues
135143
s = @inline recompute_skip!(s, n)
136144
end
137145
return s
@@ -302,8 +310,12 @@ function OnlineStatsBase.value(s::Union{MultiAlgAResSampler, MultiAlgAExpJSample
302310
end
303311
end
304312
function OnlineStatsBase.value(s::MultiAlgWRSWRSKIPSampler)
313+
nobs(s) == 0 && return s.value[1:0]
305314
if nobs(s) < length(s.value)
306-
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value))
315+
weightsnew = Vector{Float64}(undef, nobs(s))
316+
weightsnew[1] = s.weights[1]
317+
for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end
318+
return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value))
307319
else
308320
return s.value
309321
end
@@ -318,8 +330,12 @@ function ordvalue(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler})
318330
return first.(vals[sortperm(map(x -> x[2], vals))])
319331
end
320332
function ordvalue(s::MultiOrdAlgWRSWRSKIPSampler)
333+
nobs(s) == 0 && return s.value[1:0]
321334
if nobs(s) < length(s.value)
322-
return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value); ordered=true)
335+
weightsnew = Vector{Float64}(undef, nobs(s))
336+
weightsnew[1] = s.weights[1]
337+
for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end
338+
return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value); ordered=true)
323339
else
324340
return s.value[sortperm(s.ord)]
325341
end

0 commit comments

Comments
 (0)