Skip to content

Commit 4cfd371

Browse files
authored
Add iterator methods and length support to SequentialSampler (#130)
* Add iterator methods and length support to SequentialSampler * Update SamplingReduction.jl
1 parent c53754f commit 4cfd371

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/SamplingInterface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ struct SequentialSampler{S}
197197
end
198198
Base.iterate(s::SequentialSampler) = iterate(s.s)
199199
Base.iterate(s::SequentialSampler, state) = iterate(s.s, state)
200+
Base.IteratorEltype(::SequentialSampler) = Base.HasEltype()
201+
Base.eltype(::SequentialSampler) = Int
202+
Base.IteratorSize(::SequentialSampler) = Base.HasLength()
203+
Base.length(s::SequentialSampler) = s.s.n
200204

201205
"""
202206
itsample([rng], iter, method = AlgRSWRSKIP())

src/SamplingReduction.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,18 @@ function reduce_samples(t::Union{TypeS,TypeUnion}, ss::BinaryHeap...)
1111
end
1212
function reduce_samples(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::AbstractArray...)
1313
nt = length(ss)
14-
v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt)
14+
T = get_type_rs(t, ss...)
15+
v = Vector{Vector{T}}(undef, nt)
1516
n = minimum(length.(ss))
1617
ns = rand(extract_rng(rngs, 1), Multinomial(n, ps))
1718
Threads.@threads for i in 1:nt
18-
v[i] = sample(extract_rng(rngs, i), ss[i], ns[i]; replace = false)
19+
s = ss[i]
20+
vi = Vector{T}(undef, ns[i])
21+
@inbounds for (q, j) in enumerate(SequentialSampler(extract_rng(rngs, i),
22+
ns[i], length(s), AlgHiddenShuffle()))
23+
vi[q] = s[j]
24+
end
25+
v[i] = vi
1926
end
2027
return reduce(vcat, v)
2128
end

0 commit comments

Comments
 (0)