
Commit fc9e022

add MultiStepSampler
1 parent a9c5b5c commit fc9e022

2 files changed, +150 -36 lines


src/samplers.jl

Lines changed: 100 additions & 30 deletions
@@ -1,5 +1,5 @@
 using Random
-export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
+export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler

 struct SampleGenerator{S,T}
     sampler::S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
 export BatchSampler

 struct BatchSampler{names}
-    batch_size::Int
+    batchsize::Int
     rng::Random.AbstractRNG
 end

 """
-    BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
-    BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
+    BatchSampler{names}(;batchsize, rng=Random.GLOBAL_RNG)
+    BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)

-Uniformly sample **ONE** batch of `batch_size` examples for each trace specified
+Uniformly sample **ONE** batch of `batchsize` examples for each trace specified
 in `names`. If `names` is not set, all the traces will be sampled.
 """
-BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
+BatchSampler(batchsize; kw...) = BatchSampler(; batchsize=batchsize, kw...)
 BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
-BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
-BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batch_size, rng)
+BatchSampler{names}(batchsize; kw...) where {names} = BatchSampler{names}(; batchsize=batchsize, kw...)
+BatchSampler{names}(; batchsize, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batchsize, rng)

 StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = StatsBase.sample(s, t, keys(t))
 StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = StatsBase.sample(s, t, names)

 function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
-    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
+    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batchsize)
     NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
 end

@@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
     p = collect(deepcopy(t.priorities))
     w = StatsBase.FrequencyWeights(p)
     w .*= e.sampleable_inds[1:end-1]
-    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
     NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
 end

 function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
-    inds, priorities = rand(s.rng, t.priorities, s.batch_size)
+    inds, priorities = rand(s.rng, t.priorities, s.batchsize)
     NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
 end

@@ -165,41 +165,41 @@ end
 export NStepBatchSampler

 """
-    NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
+    NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)

 Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
 The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
 that in up to `n > 1` steps later in the buffer. The reward will be
 the discounted sum of the `n` rewards, with `γ` as the discount factor.

-NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
-to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
-of partial observability, for example when the state is approximated by `stack_size` consecutive
+NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
+to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
+of partial observability, for example when the state is approximated by `stacksize` consecutive
 frames.
 """
-mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
+mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
     n::Int # !!! n starts from 1
     γ::Float32
-    batch_size::Int
-    stack_size::S
-    rng::Any
+    batchsize::Int
+    stacksize::S
+    rng::R
 end

 NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
-function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
+function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
     @assert n >= 1 "n must be ≥ 1."
-    ss = stack_size == 1 ? nothing : stack_size
-    NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
+    ss = stacksize == 1 ? nothing : stacksize
+    NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
 end

-#return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
+#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
 function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
     range = copy(eb.sampleable_inds)
     ns = Vector{Int}(undef, length(eb.sampleable_inds))
-    stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
+    stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
     for idx in eachindex(range)
         step_number = eb.step_numbers[idx]
-        range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
+        range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
         ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
     end
     return range, ns
@@ -211,19 +211,19 @@ end

 function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
     weights, ns = valid_range(s, t)
-    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
     fetch(s, t, Val(names), inds, ns)
 end

 function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
     NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
 end

-#state and next_state have specialized fetch methods due to stack_size
+#state and next_state have specialized fetch methods due to stacksize
 fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
-fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
+fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stacksize+1:0, x in inds]]
 fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
-fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
+fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stacksize+1:0, (idx,x) in enumerate(inds)]]

 #reward due to discounting
 function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
@@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
     w = StatsBase.FrequencyWeights(p)
     valids, ns = valid_range(s,e)
     w .*= valids[1:end-1]
-    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
     merge(
         (key=t.keys[inds], priority=p[inds]),
         fetch(s, e, Val(names), inds, ns)
@@ -297,3 +297,73 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)

     return [make_episode(t, r, names) for r in ranges]
 end
+
+##### MultiStepSampler
+
+"""
+    MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng())
+
+A sampler that returns `n` consecutive steps of each trace. The samples are
+returned in an array of `batchsize` elements. For each element, `n` is truncated by the end
+of its episode, so the sizes of individual samples may differ.
+"""
+struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
+    n::Int
+    batchsize::Int
+    stacksize::S
+    rng::R
+end
+
+MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
+function MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
+    @assert n >= 1 "n must be ≥ 1."
+    ss = stacksize == 1 ? nothing : stacksize
+    MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
+end
+
+function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
+    range = copy(eb.sampleable_inds)
+    ns = Vector{Int}(undef, length(eb.sampleable_inds))
+    stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
+    for idx in eachindex(range)
+        step_number = eb.step_numbers[idx]
+        range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
+        ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
+    end
+    return range, ns
+end
+
+function StatsBase.sample(s::MultiStepSampler{names}, ts) where {names}
+    StatsBase.sample(s, ts, Val(names))
+end
+
+function StatsBase.sample(s::MultiStepSampler, t::EpisodesBuffer, ::Val{names}) where names
+    weights, ns = valid_range(s, t)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
+    fetch(s, t, Val(names), inds, ns)
+end
+
+function fetch(s::MultiStepSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
+    NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
+end
+
+function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
+    [trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)]
+end
+
+function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
+    [trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)]
+end
+
+function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
+    t = e.traces
+    p = collect(deepcopy(t.priorities))
+    w = StatsBase.FrequencyWeights(p)
+    valids, ns = valid_range(s,e)
+    w .*= valids[1:end-1]
+    inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
+    merge(
+        (key=t.keys[inds], priority=p[inds]),
+        fetch(s, e, Val(names), inds, ns)
+    )
+end
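
To make the new API concrete, here is a minimal usage sketch. It is not part of the commit: the buffer setup and the `push!` calls mirror the test file below, the keyword values are arbitrary, and the exact keys of the returned NamedTuple depend on the trace type.

using ReinforcementLearningTrajectories, StatsBase

# Hypothetical setup, copied from the test below: one finished episode of 5 steps.
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
push!(eb, (state = 1, action = 1))
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end

# Each batch element is a run of up to n consecutive steps; runs that hit the
# end of an episode are truncated, so element sizes may differ.
s = MultiStepSampler(eb, n=3, stacksize=2, batchsize=4)
batch = sample(s, eb)   # NamedTuple with keys such as :state, :action, :reward, ...
length(batch[:reward])  # == 4, one (possibly truncated) reward sequence per element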

test/samplers.jl

Lines changed: 50 additions & 6 deletions
@@ -78,9 +78,9 @@ import ReinforcementLearningTrajectories.fetch
         γ = 0.99
         n_stack = 2
         n_horizon = 3
-        batch_size = 1000
+        batchsize = 1000
         eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
-        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stacksize=n_stack, batchsize=batchsize)

         push!(eb, (state = 1, action = 1))
         for i = 1:5
@@ -98,12 +98,12 @@ import ReinforcementLearningTrajectories.fetch
         for key in keys(eb)
             @test haskey(batch, key)
         end
-        #state: samples with stack_size
+        #state: samples with stacksize
         states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
         @test states == [1 2 3 4 7 8 9;
                          2 3 4 5 8 9 10]
         @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
-        #next_state: samples with stack_size and nsteps forward
+        #next_state: samples with stacksize and nsteps forward
         next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
         @test next_states == [4 5 5 5 10 10 10;
                               5 6 6 6 11 11 11]
@@ -127,9 +127,9 @@ import ReinforcementLearningTrajectories.fetch
         ### CircularPrioritizedTraces and NStepBatchSampler
         γ = 0.99
         n_horizon = 3
-        batch_size = 4
+        batchsize = 4
         eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
-        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)

         push!(eb, (state = 1, action = 1))
         for i = 1:5
@@ -196,4 +196,48 @@ import ReinforcementLearningTrajectories.fetch
         @test length(b[2][:state]) == 5
         @test !haskey(b[1], :action)
     end
+    @testset "MultiStepSampler" begin
+        n_stack = 2
+        n_horizon = 3
+        batchsize = 1000
+        eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
+        s1 = MultiStepSampler(eb, n=n_horizon, stacksize=n_stack, batchsize=batchsize)
+
+        push!(eb, (state = 1, action = 1))
+        for i = 1:5
+            push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
+        end
+        push!(eb, (state = 7, action = 7))
+        for (j,i) = enumerate(8:11)
+            push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+        end
+        weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+        @test weights == [0,1,1,1,1,0,0,1,1,1,0]
+        @test ns == [3,3,3,2,1,-1,3,3,2,1,0] # the -1 is because ep_lengths[6] belongs to the 2nd episode while step_numbers[6] belongs to the 1st
+        inds = [i for i in eachindex(weights) if weights[i] == 1]
+        batch = sample(s1, eb)
+        for key in keys(eb)
+            @test haskey(batch, key)
+        end
+        #state and next_state: samples with stacksize
+        states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
+        @test states == [[1 2 3; 2 3 4], [2 3 4; 3 4 5], [3 4; 4 5], [4; 5;;], [7 8 9; 8 9 10], [8 9; 9 10], [9; 10;;]]
+        @test all(in(states), batch[:state])
+        #next_state: samples with stacksize and nsteps forward
+        next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
+        @test next_states == [[2 3 4; 3 4 5], [3 4 5; 4 5 6], [4 5; 5 6], [5; 6;;], [8 9 10; 9 10 11], [9 10; 10 11], [10; 11;;]]
+        @test all(in(next_states), batch[:next_state])
+        #all other traces sample normally
+        actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
+        @test actions == [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10], [10]]
+        @test all(in(actions), batch[:action])
+        next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
+        @test next_actions == [a .+ 1 for a in [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10], [10]]]
+        @test all(in(next_actions), batch[:next_action])
+        rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
+        @test rewards == actions
+        @test all(in(rewards), batch[:reward])
+        terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
+        @test terminals == [[a == 5 ? 1 : 0 for a in acs] for acs in actions]
+    end
 end
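
The new testset exercises only the unprioritized path. Below is a hedged sketch, not part of the commit, of how the CircularPrioritizedTraces method added in src/samplers.jl might be driven; the setup reuses the default_priority value from the existing NStepBatchSampler test, and the priority-update step is only indicated by a comment.

# Hypothetical prioritized-replay usage of MultiStepSampler.
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
push!(eb, (state = 1, action = 1))
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end

s = MultiStepSampler(eb, n=3, batchsize=4)
batch = sample(s, eb)
batch[:key]       # keys of the sampled steps in the circular buffer
batch[:priority]  # their current priorities; a learner would typically update these after computing TD errors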
