Skip to content

Commit 34028b0

Browse files
committed
add algorithm for weighted stream sampling
1 parent f378aad commit 34028b0

File tree

3 files changed

+98
-23
lines changed

3 files changed

+98
-23
lines changed

src/SamplingInterface.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,26 @@ is assumed to have a `eltype` of `T`. The methods implemented in
128128
[`StreamSampler`](@ref) require the knowledge of the total number
129129
of elements in the stream `N`, if not provided it is assumed to be
130130
available by calling `length(iter)`.
131+
132+
-----
133+
134+
StreamSampler{T}([rng], iter, wfunc, n, W, method = AlgORDWSWR())
135+
136+
Initializes a weigthed stream sampler, which can then be iterated over
137+
to return the sampling elements of the iterable `iter` which
138+
is assumed to have a `eltype` of `T`. The methods implemented in
139+
[`StreamSampler`](@ref) for weighted streams require the knowledge
140+
of the total weight of the stream `W` and a weight function `wfunc`
141+
specifying how to map an element to its weight.
131142
"""
132143
struct StreamSampler{T} 1 === 1 end
133144

145+
function StreamSampler{T}(iter, wfunc::Function, n, W, method::StreamAlgorithm = AlgORDWSWR()) where T
146+
return StreamSampler{T}(Random.default_rng(), iter, wfunc, n, W, method)
147+
end
148+
function StreamSampler{T}(rng::AbstractRNG, iter, wfunc::Function, n, W, method::StreamAlgorithm = AlgORDWSWR()) where T
149+
return StreamSampler{T}(rng, iter, wfunc, n, W, method)
150+
end
134151
function StreamSampler{T}(iter, n, N, method::StreamAlgorithm = AlgD()) where T
135152
return StreamSampler{T}(Random.default_rng(), iter, n, N, method)
136153
end
@@ -159,10 +176,11 @@ If the iterator is empty, it returns `nothing`.
159176
itsample([rng], iter, wfunc, n::Int, method = AlgAExpJ(); ordered = false)
160177
161178
Return a vector of `n` random elements of the iterator,
162-
optionally specifying a `rng` (which defaults to `Random.default_rng()`)
163-
a weight function `wfunc` and a `method`. `ordered` dictates whether an
164-
ordered sample (also called a sequential sample, i.e. a sample where items
165-
appear in the same order as in `iter`) must be collected.
179+
optionally specifying a `rng` (which defaults to `Random.default_rng()`),
180+
a weight function `wfunc` specifying how to map an element to its weight
181+
and a `method`. `ordered` dictates whether an ordered sample (also called a
182+
sequential sample, i.e. a sample where items appear in the same order as in
183+
`iter`) must be collected.
166184
167185
If the iterator has less than `n` elements, in the case of sampling without
168186
replacement, it returns a vector of those elements.

src/SortedSamplingMulti.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,50 @@
11

2+
struct MultiAlgWeightedORDSampler{T,R,I,F}
3+
rng::R
4+
it::I
5+
f::F
6+
n::Int
7+
W::Float64
8+
function MultiAlgWeightedORDSampler{T}(rng::R, it::I, f::F, n, W) where {T,R,I,F}
9+
return new{T,R,I,F}(rng, it, f, n, W)
10+
end
11+
end
12+
13+
@inline function Base.iterate(s::MultiAlgWeightedORDSampler)
14+
local el, state_el
15+
w, curx, k = 0.0, 0.0, 0
16+
for i in s.n:-1:1
17+
curx += (1-exp(-randexp(s.rng)/i))*(1-curx)
18+
while w < curx * s.W
19+
nstate = k == 0 ? iterate(s.it) : iterate(s.it, state_el)
20+
nstate == nothing && return nothing
21+
k += 1
22+
el, state_el = nstate
23+
w += s.f(el)
24+
end
25+
return (el, (el, w, state_el, curx, i-1))
26+
end
27+
end
28+
@inline function Base.iterate(s::MultiAlgWeightedORDSampler, state)
29+
state[end] == 0 && return nothing
30+
el, w, state_el, curx, n = state
31+
for i in n:-1:1
32+
curx += (1-exp(-randexp(s.rng)/i))*(1-curx)
33+
while w < curx * s.W
34+
nstate = iterate(s.it, state_el)
35+
nstate == nothing && return nothing
36+
el, state_el = nstate
37+
w += s.f(el)
38+
end
39+
return (el, (el, w, state_el, curx, i-1))
40+
end
41+
end
42+
43+
Base.IteratorEltype(::MultiAlgWeightedORDSampler) = Base.HasEltype()
44+
Base.eltype(::MultiAlgWeightedORDSampler{T}) where T = T
45+
Base.IteratorSize(::MultiAlgWeightedORDSampler) = Base.HasLength()
46+
Base.length(s::MultiAlgWeightedORDSampler) = s.n
47+
248
struct MultiAlgORDSampler{T,R,I,D} <: AbstractStreamSampler
349
rng::R
450
it::I
@@ -9,6 +55,9 @@ struct MultiAlgORDSampler{T,R,I,D} <: AbstractStreamSampler
955
end
1056
end
1157

58+
function StreamSampler{T}(rng::AbstractRNG, iter, wfunc::Function, n, W, ::AlgORDWSWR) where T
59+
return MultiAlgWeightedORDSampler{T}(rng, iter, wfunc, n, W)
60+
end
1261
function StreamSampler{T}(rng::AbstractRNG, iter, n, N, ::AlgD) where T
1362
return MultiAlgORDSampler{T}(rng, iter, min(n, N), SeqSampleIter(rng, N, min(n, N)))
1463
end

src/StreamSampling.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using StatsBase
1919

2020
export fit!, merge!, value, ordvalue, nobs, itsample
2121
export AbstractReservoirSampler, ReservoirSampler, StreamSampler
22-
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP, AlgD, AlgORDSWR
22+
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP, AlgD, AlgORDSWR, AlgORDWSWR
2323

2424
struct ImmutSampler end
2525
struct MutSampler end
@@ -34,24 +34,6 @@ abstract type AbstractWeightedReservoirSampler <: AbstractReservoirSampler end
3434
abstract type StreamAlgorithm end
3535
abstract type ReservoirAlgorithm <: StreamAlgorithm end
3636

37-
"""
38-
Implements random stream sampling without replacement. To be used with [`StreamSampler`](@ref)
39-
or [`itsample`](@ref).
40-
41-
Adapted from algorithm D described in "An Efficient Algorithm for Sequential Random Sampling,
42-
J. S. Vitter, 1987".
43-
"""
44-
struct AlgD <: StreamAlgorithm end
45-
46-
"""
47-
Implements random stream sampling with replacement. To be used with [`StreamSampler`](@ref)
48-
or [`itsample`](@ref).
49-
50-
Adapted from algorithm 4 described in "Generating Sorted Lists of Random Numbers, J. L. Bentley
51-
et al., 1980".
52-
"""
53-
struct AlgORDSWR <: StreamAlgorithm end
54-
5537
"""
5638
Implements random reservoir sampling without replacement. To be used with [`ReservoirSampler`](@ref)
5739
or [`itsample`](@ref).
@@ -104,6 +86,32 @@ Replacement, A. Meligrana, 2024".
10486
"""
10587
struct AlgWRSWRSKIP <: ReservoirAlgorithm end
10688

89+
"""
90+
Implements random stream sampling without replacement. To be used with [`StreamSampler`](@ref)
91+
or [`itsample`](@ref).
92+
93+
Adapted from algorithm D described in "An Efficient Algorithm for Sequential Random Sampling,
94+
J. S. Vitter, 1987".
95+
"""
96+
struct AlgD <: StreamAlgorithm end
97+
98+
"""
99+
Implements random stream sampling with replacement. To be used with [`StreamSampler`](@ref)
100+
or [`itsample`](@ref).
101+
102+
Adapted from algorithm 4 described in "Generating Sorted Lists of Random Numbers, J. L. Bentley
103+
et al., 1980".
104+
"""
105+
struct AlgORDSWR <: StreamAlgorithm end
106+
107+
"""
108+
Implements weighted random stream sampling with replacement. To be used with [`StreamSampler`](@ref).
109+
110+
Adapted from algorithm 3 described in "An asymptotically optimal, online algorithm for weighted random
111+
sampling with replacement, M. Startek, 2016".
112+
"""
113+
struct AlgORDWSWR <: StreamAlgorithm end
114+
107115
include("SamplingUtils.jl")
108116
include("SamplingInterface.jl")
109117
include("SortedSamplingSingle.jl")

0 commit comments

Comments
 (0)