Skip to content

Commit 1407351

Browse files
committed
fix range
1 parent 5635aa1 commit 1407351

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/samplers.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,18 +258,23 @@ end
258258
StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer) = StatsBase.sample(s,t,keys(t))
259259
StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where names = StatsBase.sample(s,t,names)
260260

261+
function make_episode(t::EpisodesBuffer, range, names)
262+
nt = NamedTuple{names}(map(x -> collect(t[Val(x)][range]), names))
263+
Episode(nt)
264+
end
265+
261266
function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
262267
ranges = UnitRange{Int}[]
263268
idx = 1
264269
while idx < length(t)
265270
if t.sampleable_inds[idx] == 1
266-
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] + 1
271+
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx]
267272
push!(ranges,idx:last_state_idx)
268273
idx = last_state_idx + 1
269274
else
270275
idx += 1
271276
end
272277
end
273278

274-
return [Episode(NamedTuple{names}(map(x -> collect(t[Val(x)][r]), names))) for r in ranges]
279+
return [make_episode(t, r, names) for r in ranges]
275280
end

0 commit comments

Comments
 (0)