@@ -168,86 +168,77 @@ export NStepBatchSampler
168168
169169Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
170170The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
171- that in up to `n > 1` steps later in the buffer (or the last of the episode) . The reward will be
171+ that in up to `n > 1` steps later in the buffer. The reward will be
172172the discounted sum of the `n` rewards, with `γ` as the discount factor.
173173
174174NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
175175to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
176176of partial observability, for example when the state is approximated by `stack_size` consecutive
177177frames.
178178"""
179- mutable struct NStepBatchSampler{traces }
179+ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} }
180180 n:: Int # !!! n starts from 1
181181 γ:: Float32
182182 batch_size:: Int
183- stack_size:: Union{Nothing,Int}
183+ stack_size:: S
184184 rng:: Any
185185end
186186
187187NStepBatchSampler (; kw... ) = NStepBatchSampler {SS′ART} (; kw... )
188188function NStepBatchSampler {names} (; n, γ, batch_size= 32 , stack_size= nothing , rng= Random. GLOBAL_RNG) where {names}
189189 @assert n >= 1 " n must be ≥ 1."
190- NStepBatchSampler {names} (n, γ, batch_size, stack_size, rng)
190+ NStepBatchSampler {names} (n, γ, batch_size, stack_size == 1 ? nothing : stack_size , rng)
191191end
192192
193-
194- function valid_range_nbatchsampler (s:: NStepBatchSampler , ts)
195- # think about the extreme case where s.stack_size == 1 and s.n == 1
196- isnothing (s. stack_size) ? (1 : (length (ts)- s. n+ 1 )) : (s. stack_size: (length (ts)- s. n+ 1 ))
193+ function valid_range_nbatchsampler (s:: NStepBatchSampler , eb:: EpisodesBuffer )
194+ range = copy (eb. sampleable_inds)
195+ stack_size = isnothing (s. stack_size) ? 1 : s. stack_size
196+ for idx in eachindex (range)
197+ valid = eb. step_numbers[idx] >= stack_size && eb. step_numbers[idx] <= eb. episodes_lengths[idx] + 1 - eb. n && eb. sampleable_inds[idx]
198+ range[idx] = valid
199+ end
200+ return range
197201end
202+
198203function StatsBase. sample (s:: NStepBatchSampler{names} , ts) where {names}
199- valid_range = valid_range_nbatchsampler (s, ts)
200- inds = rand (s. rng, valid_range, s. batch_size)
201- StatsBase. sample (s, ts, Val (names), inds)
204+ StatsBase. sample (s, ts, Val (names))
202205end
203206
204- function StatsBase. sample (s:: NStepBatchSampler{names} , ts:: EpisodesBuffer ) where {names}
205- valid_range = valid_range_nbatchsampler (s, ts)
206- valid_range = valid_range[valid_range .∈ (findall (ts. sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
207- inds = rand (s. rng, valid_range, s. batch_size)
208- StatsBase. sample (s, ts, Val (names), inds)
207+ function StatsBase. sample (s:: NStepBatchSampler , t:: EpisodesBuffer , names)
208+ valid_range = valid_range_nbatchsampler (s, t)
209+ StatsBase. sample (s, t. traces, names, StatsBase. FrequencyWeights (valid_range))
209210end
210211
211-
212- function StatsBase. sample (nbs:: NStepBatchSampler , ts, :: Val{SS′ART} , inds)
213- if isnothing (nbs. stack_size)
214- s = ts[:state ][inds]
215- s′ = ts[:next_state ][inds.+ (nbs. n- 1 )]
216- else
217- s = ts[:state ][[x + i for i in - nbs. stack_size+ 1 : 0 , x in inds]]
218- s′ = ts[:next_state ][[x + nbs. n - 1 + i for i in - nbs. stack_size+ 1 : 0 , x in inds]]
219- end
220-
221- a = ts[:action ][inds]
222- t_horizon = ts[:terminal ][[x + j for j in 0 : nbs. n- 1 , x in inds]]
223- r_horizon = ts[:reward ][[x + j for j in 0 : nbs. n- 1 , x in inds]]
224-
225- @assert ndims (t_horizon) == 2
226- t = any (t_horizon, dims= 1 ) |> vec
227-
228- @assert ndims (r_horizon) == 2
229- r = map (eachcol (r_horizon), eachcol (t_horizon)) do r⃗, t⃗
230- foldr (((rr, tt), init) -> rr + nbs. γ * init * (1 - tt), zip (r⃗, t⃗); init= 0.0f0 )
231- end
232-
233- NamedTuple {SS′ART} (map (collect, (s, s′, a, r, t)))
212+ function StatsBase. sample (s:: NStepBatchSampler , t:: AbstractTraces , names, weights = StatsBase. UnitWeights {Int} (length (t)))
213+ inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batch_size)
214+ NamedTuple{names}map (name -> collect (fetch (s, ts[name], Val (name), inds)), names)
234215end
235216
236- function StatsBase. sample (s:: NStepBatchSampler , ts, :: Val{SS′L′ART} , inds)
237- s, s′, a, r, t = StatsBase. sample (s, ts, Val (SSART), inds)
238- l = consecutive_view (ts[:next_legal_actions_mask ], inds)
239- NamedTuple {SSLART} (map (collect, (s, s′, l, a, r, t)))
217+ # state and next_state have specialized fetch methods due to stack_size
218+ fetch (:: NStepBatchSampler{names, Nothing} , trace, :: Val{:state} , inds) where {names} = trace[inds]
219+ fetch (s:: NStepBatchSampler{names, Int} , trace, :: Val{:state} , inds) where {names} = trace[[x + s. n - 1 + i for i in - s. stack_size+ 1 : 0 , x in inds]]
220+ fetch (s:: NStepBatchSampler{names, Nothing} , trace, :: Val{:next_state} , inds) where {names} = trace[inds.+ (s. n- 1 )]
221+ fetch (s:: NStepBatchSampler{names, Int} , trace, :: Val{:next_state} , inds) where {names} = trace[[x + s. n - 1 + i for i in - s. stack_size+ 1 : 0 , x in inds]]
222+ # reward due to discounting
223+ function fetch (s:: NStepBatchSampler{names} , trace, :: Val{:reward} , inds) where {names}
224+ rewards = trace[[x + j for j in 0 : nbs. n- 1 , x in inds]]
225+ return reduce ((x,y)-> x + s. γ* y, rewards, init = zero (eltype (rewards)), dims = 1 )
240226end
227+ # terminal is that of the nth step
228+ fetch (s:: NStepBatchSampler{names} , trace, :: Val{:terminal} , inds) where {names} = trace[inds.+ s. n]
229+ # right multiplex traces must be n-step sampled
230+ fetch (:: NStepBatchSampler{names} , trace:: RelativeTrace{1,0} , :: Val{<:Symbol} , inds) where {names} = trace[inds.+ (s. n- 1 )]
231+ # normal trace types are fetched at inds
232+ fetch (:: NStepBatchSampler{names} , trace, :: Val{<:Symbol} , inds) where {names} = trace[inds] # other types of trace are sampled normaly
241233
242234function StatsBase. sample (s:: NStepBatchSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
243235 t = e. traces
244236 st = deepcopy (t. priorities)
245- st .*= e . sampleable_inds[ 1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
237+ st .*= valid_range_nbatchsampler (s,e) # temporary sumtree that puts 0 priority to non sampleable indices.
246238 inds, priorities = rand (s. rng, st, s. batch_size)
247-
248239 merge (
249240 (key= t. keys[inds], priority= priorities),
250- StatsBase . sample (s, t. traces, Val (names), inds)
241+ fetch (s, t. traces, Val (names), inds)
251242 )
252243end
253244
0 commit comments