Skip to content

Commit 5b68fe6

Browse files
author
Dharanish
committed
Fix bugs with SARTSATraces
Bug 1: When more traces are pushed into a CircularArraySARTSATraces than its capacity, the state and action traces fall out of alignment. Bug 2: sampleable_inds were incorrect for CircularArraySARTSATraces. Bug 3: CircularArraySARTSATraces could not be sampled by an EpisodesSampler.
1 parent f222bad commit 5b68fe6

File tree

4 files changed

+55
-23
lines changed

4 files changed

+55
-23
lines changed

src/common/CircularArraySARTSATraces.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ function CircularArraySARTSATraces(;
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
2828
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
2929
Traces(
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
30+
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
31+
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
3232
)
3333
end
3434

src/episodes.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)
138138

139139
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)
140140

141+
max_length(eb::EpisodesBuffer) = max_length(eb.traces)
142+
141143
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
142144
push!(eb.traces, xs)
143145
partial = ispartial_insert(eb, xs)
@@ -146,10 +148,12 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
146148
push!(eb.episodes_lengths, 0)
147149
push!(eb.sampleable_inds, 0)
148150
elseif !partial #typical inserting
149-
if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later
150-
eb.sampleable_inds[end-1] = 1
151-
else #case when we don't, length of traces and eb will match.
152-
eb.sampleable_inds[end] = 1 #previous step is now indexable
151+
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
152+
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
153+
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
154+
end
155+
else
156+
eb.sampleable_inds[end] = 1 # otherwise, previous step is now indexable
153157
end
154158
push!(eb.sampleable_inds, 0) #this one is no longer
155159
ep_length = last(eb.step_numbers)
@@ -172,6 +176,14 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
172176
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
173177
end
174178

179+
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
180+
if max_length(eb) == capacity(eb.traces)
181+
popfirst!(eb)
182+
end
183+
push!(eb.traces, xs.namedtuple)
184+
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
185+
end
186+
175187
for f in (:pop!, :popfirst!)
176188
@eval function Base.$f(eb::EpisodesBuffer)
177189
$f(eb.episodes_lengths)

test/common.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@test length(t) == 0
2525
end
2626

27-
@testset "CircularArraySARTSTraces" begin
27+
@testset "CircularArraySARTSATraces" begin
2828
t = CircularArraySARTSATraces(;
2929
capacity=3,
3030
state=Float32 => (2, 3),
@@ -35,13 +35,14 @@ end
3535

3636
@test t isa CircularArraySARTSATraces
3737

38-
push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu)
38+
push!(t, (state=ones(Float32, 2, 3),))
39+
push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu)
3940
@test length(t) == 0
4041

4142
push!(t, (reward=1.0f0, terminal=false) |> gpu)
42-
@test length(t) == 0 # next_state and next_action is still missing
43+
@test length(t) == 0 # next_action is still missing
4344

44-
push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu)
45+
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu)
4546
@test length(t) == 1
4647

4748
# this will trigger the scalar indexing of CuArray
@@ -55,17 +56,18 @@ end
5556
)
5657

5758
push!(t, (reward=2.0f0, terminal=false))
58-
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu)
59+
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu)
5960

6061
@test length(t) == 2
6162

6263
push!(t, (reward=3.0f0, terminal=false))
63-
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu)
64+
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu)
6465

6566
@test length(t) == 3
6667

6768
push!(t, (reward=4.0f0, terminal=false))
68-
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu)
69+
push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu)
70+
push!(t, (reward=5.0f0, terminal=false))
6971

7072
@test length(t) == 3
7173

test/episodes.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ using Test
100100
for i = 1:5
101101
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
102102
@test eb.sampleable_inds[end] == 0
103-
@test eb.sampleable_inds[end-1] == 1
103+
@test eb.sampleable_inds[end-1] == 0
104+
if length(eb) >= 1
105+
@test eb.sampleable_inds[end-2] == 1
106+
end
104107
@test eb.step_numbers[end] == i + 1
105108
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
106109
end
@@ -116,24 +119,30 @@ using Test
116119
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
117120
@test eb[:action][6] == 6
118121
@test eb[:next_action][5] == 6
119-
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
122+
@test eb[6][:reward] == 0 broken = true #6 is not a valid index and cannot be indexed because a PartialNamedTuple is used
120123
ep2_len = 0
121124
for (j,i) = enumerate(8:11)
122125
ep2_len += 1
123126
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
124127
@test eb.sampleable_inds[end] == 0
125-
@test eb.sampleable_inds[end-1] == 1
128+
@test eb.sampleable_inds[end-1] == 0
129+
if eb.step_numbers[end] > 2
130+
@test eb.sampleable_inds[end-2] == 1
131+
end
126132
@test eb.step_numbers[end] == j + 1
127133
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
128134
end
129-
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
135+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
130136
@test length(eb.traces) == 9 #an action is missing at this stage
131137
#three last steps replace oldest steps in the buffer.
132138
for (i, s) = enumerate(12:13)
133139
ep2_len += 1
134140
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
135141
@test eb.sampleable_inds[end] == 0
136-
@test eb.sampleable_inds[end-1] == 1
142+
@test eb.sampleable_inds[end-1] == 0
143+
if eb.step_numbers[end] > 2
144+
@test eb.sampleable_inds[end-2] == 1
145+
end
137146
@test eb.step_numbers[end] == i + 1 + 4
138147
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
139148
end
@@ -298,7 +307,10 @@ using Test
298307
for i = 1:5
299308
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
300309
@test eb.sampleable_inds[end] == 0
301-
@test eb.sampleable_inds[end-1] == 1
310+
@test eb.sampleable_inds[end-1] == 0
311+
if eb.step_numbers[end] > 2
312+
@test eb.sampleable_inds[end-2] == 1
313+
end
302314
@test eb.step_numbers[end] == i + 1
303315
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
304316
end
@@ -320,17 +332,23 @@ using Test
320332
ep2_len += 1
321333
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
322334
@test eb.sampleable_inds[end] == 0
323-
@test eb.sampleable_inds[end-1] == 1
335+
@test eb.sampleable_inds[end-1] == 0
336+
if eb.step_numbers[end] > 2
337+
@test eb.sampleable_inds[end-2] == 1
338+
end
324339
@test eb.step_numbers[end] == j + 1
325340
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
326341
end
327-
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
342+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
328343
@test length(eb.traces) == 9 #an action is missing at this stage
329344
for (i, s) = enumerate(12:13)
330345
ep2_len += 1
331346
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
332347
@test eb.sampleable_inds[end] == 0
333-
@test eb.sampleable_inds[end-1] == 1
348+
@test eb.sampleable_inds[end-1] == 0
349+
if eb.step_numbers[end] > 2
350+
@test eb.sampleable_inds[end-2] == 1
351+
end
334352
@test eb.step_numbers[end] == i + 1 + 4
335353
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
336354
end

0 commit comments

Comments
 (0)