Skip to content

Commit 31ce648

Browse files
Attempt to fix NStepBatchSampler...
1 parent 047e898 commit 31ce648

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/samplers.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,25 @@ end
173173
NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
174174
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
175175

176+
177+
function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
178+
# think about the extreme case where s.stack_size == 1 and s.n == 1
179+
isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
180+
end
176181
function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
177-
valid_range = isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1
182+
valid_range = valid_range_nbatchsampler(s, ts)
178183
inds = rand(s.rng, valid_range, s.batch_size)
179184
StatsBase.sample(s, ts, Val(names), inds)
180185
end
181186

187+
function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
188+
valid_range = valid_range_nbatchsampler(s, ts)
189+
valid_range = valid_range[valid_range findall(ts.sampleable_inds)] # Ensure that the valid range is within the sampleable indices
190+
inds = rand(s.rng, valid_range, s.batch_size)
191+
StatsBase.sample(s, ts, Val(names), inds)
192+
end
193+
194+
182195
function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
183196
if isnothing(nbs.stack_size)
184197
s = ts[:state][inds]

0 commit comments

Comments
 (0)