11using Random
2- export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
2+ export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler
33
44struct SampleGenerator{S,T}
55 sampler:: S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
2929export BatchSampler
3030
3131struct BatchSampler{names}
32- batch_size :: Int
32+ batchsize :: Int
3333 rng:: Random.AbstractRNG
3434end
3535
3636"""
37- BatchSampler{names}(;batch_size , rng=Random.GLOBAL_RNG)
38- BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
37+ BatchSampler{names}(;batchsize , rng=Random.GLOBAL_RNG)
38+ BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)
3939
40- Uniformly sample **ONE** batch of `batch_size ` examples for each trace specified
40+ Uniformly sample **ONE** batch of `batchsize ` examples for each trace specified
4141in `names`. If `names` is not set, all the traces will be sampled.
4242"""
43- BatchSampler (batch_size ; kw... ) = BatchSampler (; batch_size = batch_size , kw... )
43+ BatchSampler (batchsize ; kw... ) = BatchSampler (; batchsize = batchsize , kw... )
4444BatchSampler (; kw... ) = BatchSampler {nothing} (; kw... )
45- BatchSampler {names} (batch_size ; kw... ) where {names} = BatchSampler {names} (; batch_size = batch_size , kw... )
46- BatchSampler {names} (; batch_size , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batch_size , rng)
45+ BatchSampler {names} (batchsize ; kw... ) where {names} = BatchSampler {names} (; batchsize = batchsize , kw... )
46+ BatchSampler {names} (; batchsize , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batchsize , rng)
4747
4848StatsBase. sample (s:: BatchSampler{nothing} , t:: AbstractTraces ) = StatsBase. sample (s, t, keys (t))
4949StatsBase. sample (s:: BatchSampler{names} , t:: AbstractTraces ) where {names} = StatsBase. sample (s, t, names)
5050
5151function StatsBase. sample (s:: BatchSampler , t:: AbstractTraces , names, weights = StatsBase. UnitWeights {Int} (length (t)))
52- inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batch_size )
52+ inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batchsize )
5353 NamedTuple {names} (map (x -> collect (t[Val (x)][inds]), names))
5454end
5555
@@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7575 p = collect (deepcopy (t. priorities))
7676 w = StatsBase. FrequencyWeights (p)
7777 w .*= e. sampleable_inds[1 : end - 1 ]
78- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
78+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
7979 NamedTuple {(:key, :priority, names...)} ((t. keys[inds], p[inds], map (x -> collect (t. traces[Val (x)][inds]), names)... ))
8080end
8181
8282function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
83- inds, priorities = rand (s. rng, t. priorities, s. batch_size )
83+ inds, priorities = rand (s. rng, t. priorities, s. batchsize )
8484 NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x)][inds]), names)... ))
8585end
8686
@@ -165,41 +165,41 @@ end
165165export NStepBatchSampler
166166
167167"""
168- NStepBatchSampler{names}(; n, γ, batch_size =32, stack_size =nothing, rng=Random.GLOBAL_RNG)
168+ NStepBatchSampler{names}(; n, γ, batchsize =32, stacksize =nothing, rng=Random.GLOBAL_RNG)
169169
170170Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
171171The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172172that in up to `n > 1` steps later in the buffer. The reward will be
173173the discounted sum of the `n` rewards, with `γ` as the discount factor.
174174
175- NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size ` is set
176- to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
177- of partial observability, for example when the state is approximated by `stack_size ` consecutive
175+ NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize ` is set
176+ to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
177+ of partial observability, for example when the state is approximated by `stacksize ` consecutive
178178frames.
179179"""
180- mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} }
180+ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
181181 n:: Int # !!! n starts from 1
182182 γ:: Float32
183- batch_size :: Int
184- stack_size :: S
185- rng:: Any
183+ batchsize :: Int
184+ stacksize :: S
185+ rng:: R
186186end
187187
188188NStepBatchSampler (t:: AbstractTraces ; kw... ) = NStepBatchSampler {keys(t)} (; kw... )
189- function NStepBatchSampler {names} (; n, γ, batch_size = 32 , stack_size = nothing , rng= Random. GLOBAL_RNG ) where {names}
189+ function NStepBatchSampler {names} (; n, γ, batchsize = 32 , stacksize = nothing , rng= Random. default_rng () ) where {names}
190190 @assert n >= 1 " n must be ≥ 1."
191- ss = stack_size == 1 ? nothing : stack_size
192- NStepBatchSampler {names, typeof(ss)} (n, γ, batch_size , ss, rng)
191+ ss = stacksize == 1 ? nothing : stacksize
192+ NStepBatchSampler {names, typeof(ss), typeof(rng) } (n, γ, batchsize , ss, rng)
193193end
194194
195- # return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
195+ # return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
196196function valid_range (s:: NStepBatchSampler , eb:: EpisodesBuffer )
197197 range = copy (eb. sampleable_inds)
198198 ns = Vector {Int} (undef, length (eb. sampleable_inds))
199- stack_size = isnothing (s. stack_size ) ? 1 : s. stack_size
199+ stacksize = isnothing (s. stacksize ) ? 1 : s. stacksize
200200 for idx in eachindex (range)
201201 step_number = eb. step_numbers[idx]
202- range[idx] = step_number >= stack_size && eb. sampleable_inds[idx]
202+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
203203 ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
204204 end
205205 return range, ns
@@ -211,19 +211,19 @@ end
211211
212212function StatsBase. sample (s:: NStepBatchSampler , t:: EpisodesBuffer , :: Val{names} ) where names
213213 weights, ns = valid_range (s, t)
214- inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batch_size )
214+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize )
215215 fetch (s, t, Val (names), inds, ns)
216216end
217217
218218function fetch (s:: NStepBatchSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
219219 NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
220220end
221221
222- # state and next_state have specialized fetch methods due to stack_size
222+ # state and next_state have specialized fetch methods due to stacksize
223223fetch (:: NStepBatchSampler{names, Nothing} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[inds]
224- fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stack_size + 1 : 0 , x in inds]]
224+ fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stacksize + 1 : 0 , x in inds]]
225225fetch (:: NStepBatchSampler{names, Nothing} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[inds .+ ns .- 1 ]
226- fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stack_size + 1 : 0 , (idx,x) in enumerate (inds)]]
226+ fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stacksize + 1 : 0 , (idx,x) in enumerate (inds)]]
227227
228228# reward due to discounting
229229function fetch (s:: NStepBatchSampler , trace:: AbstractTrace , :: Val{:reward} , inds, ns)
@@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247247 w = StatsBase. FrequencyWeights (p)
248248 valids, ns = valid_range (s,e)
249249 w .*= valids[1 : end - 1 ]
250- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
250+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
251251 merge (
252252 (key= t. keys[inds], priority= p[inds]),
253253 fetch (s, e, Val (names), inds, ns)
@@ -297,3 +297,73 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
297297
298298 return [make_episode (t, r, names) for r in ranges]
299299end
300+
301+ # ####MultiStepSampler
302+
303+ """
304+ MultiStepSampler{names}(batchsize, stacksize, n, rng)
305+
306+ A sampler that returns n consecutive steps of each trace. The samples are
307+ returned in an array of batchsize elements. For each element, n is truncated by the end
308+ of its episode. This means that the dimensions of each sample are not the same.
309+ """
310+ struct MultiStepSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
311+ n:: Int
312+ batchsize:: Int
313+ stacksize:: Int
314+ rng:: R
315+ end
316+
317+ MultiStepSampler (t:: AbstractTraces ; kw... ) = MultiStepSampler {keys(t)} (; kw... )
318+ function MultiStepSampler {names} (; n, batchsize= 32 , stacksize= nothing , rng= Random. default_rng ()) where {names}
319+ @assert n >= 1 " n must be ≥ 1."
320+ ss = stacksize == 1 ? nothing : stacksize
321+ MultiStepSampler {names, typeof(ss), typeof(rng)} (n, batchsize, ss, rng)
322+ end
323+
324+ function valid_range (s:: MultiStepSampler , eb:: EpisodesBuffer )
325+ range = copy (eb. sampleable_inds)
326+ ns = Vector {Int} (undef, length (eb. sampleable_inds))
327+ stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
328+ for idx in eachindex (range)
329+ step_number = eb. step_numbers[idx]
330+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
331+ ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
332+ end
333+ return range, ns
334+ end
335+
336+ function StatsBase. sample (s:: MultiStepSampler{names} , ts) where {names}
337+ StatsBase. sample (s, ts, Val (names))
338+ end
339+
340+ function StatsBase. sample (s:: MultiStepSampler , t:: EpisodesBuffer , :: Val{names} ) where names
341+ weights, ns = valid_range (s, t)
342+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize)
343+ fetch (s, t, Val (names), inds, ns)
344+ end
345+
346+ function fetch (s:: MultiStepSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
347+ NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
348+ end
349+
350+ function fetch (:: MultiStepSampler , trace, :: Val , inds, ns)
351+ [trace[idx: (idx + ns[i] - 1 )] for (i,idx) in enumerate (inds)]
352+ end
353+
354+ function fetch (s:: MultiStepSampler{names, Int} , trace:: AbstractTrace , :: Union{Val{:state}, Val{:next_state}} , inds, ns) where {names}
355+ [trace[[idx + i + n - 1 for i in - s. stacksize+ 1 : 0 , n in 1 : ns[j]]] for (j,idx) in enumerate (inds)]
356+ end
357+
358+ function StatsBase. sample (s:: MultiStepSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
359+ t = e. traces
360+ p = collect (deepcopy (t. priorities))
361+ w = StatsBase. FrequencyWeights (p)
362+ valids, ns = valid_range (s,e)
363+ w .*= valids[1 : end - 1 ]
364+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize)
365+ merge (
366+ (key= t. keys[inds], priority= p[inds]),
367+ fetch (s, e, Val (names), inds, ns)
368+ )
369+ end
0 commit comments